From f6317d80030d65707e868d355eb14e2d58f59637 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Sat, 12 Aug 2023 04:54:50 -0400 Subject: [PATCH] infra: enable compiled wheels for all supported Python (#201) --- .github/workflows/compile-pypi.yml | 13 +- ADDING_NEW_MODEL.md | 2 +- STYLE.md | 7 +- changelog.d/201.feature.md | 1 + hatch.toml | 12 +- pyproject.toml | 86 +++---- src/openllm/__init__.py | 56 +++-- src/openllm/_configuration.py | 100 ++------ src/openllm/_generation.py | 12 +- src/openllm/_llm.py | 114 +++------ src/openllm/_prompt.py | 13 +- src/openllm/_quantisation.py | 15 +- src/openllm/_schema.py | 19 +- src/openllm/_service.py | 43 +--- src/openllm/_strategies.py | 27 +- src/openllm/_types.py | 100 -------- src/openllm/_typing_compat.py | 102 ++++++++ src/openllm/bundle/__init__.py | 13 +- src/openllm/bundle/_package.py | 58 ++--- src/openllm/bundle/oci/__init__.py | 84 ++----- src/openllm/cli/_factory.py | 88 +++---- src/openllm/cli/_sdk.py | 27 +- src/openllm/cli/entrypoint.py | 107 +++----- src/openllm/cli/extension/list_models.py | 17 +- src/openllm/cli/extension/playground.py | 27 +- src/openllm/cli/termui.py | 15 +- src/openllm/client/__init__.py | 16 +- src/openllm/client/runtimes/__init__.py | 6 +- src/openllm/client/runtimes/base.py | 232 ++++++------------ src/openllm/client/runtimes/grpc.py | 127 +++++----- src/openllm/client/runtimes/http.py | 148 +++++------ src/openllm/exceptions.py | 9 - src/openllm/models/__init__.py | 4 +- src/openllm/models/auto/__init__.py | 8 +- src/openllm/models/auto/configuration_auto.py | 11 +- src/openllm/models/auto/factory.py | 21 +- src/openllm/models/auto/modeling_auto.py | 5 +- src/openllm/models/auto/modeling_flax_auto.py | 4 +- src/openllm/models/auto/modeling_tf_auto.py | 4 +- src/openllm/models/auto/modeling_vllm_auto.py | 4 +- src/openllm/models/baichuan/__init__.py | 1 - .../models/baichuan/configuration_baichuan.py | 1 - .../models/baichuan/modeling_baichuan.py | 7 +- .../models/baichuan/modeling_vllm_baichuan.py | 8 +- src/openllm/models/chatglm/__init__.py | 1 - .../models/chatglm/configuration_chatglm.py | 1 - .../models/chatglm/modeling_chatglm.py | 5 +- src/openllm/models/dolly_v2/__init__.py | 1 - .../models/dolly_v2/configuration_dolly_v2.py | 5 +- .../models/dolly_v2/modeling_dolly_v2.py | 178 +++++++------- .../models/dolly_v2/modeling_vllm_dolly_v2.py | 7 +- src/openllm/models/falcon/__init__.py | 1 - .../models/falcon/configuration_falcon.py | 1 - src/openllm/models/falcon/modeling_falcon.py | 10 +- src/openllm/models/flan_t5/__init__.py | 1 - .../models/flan_t5/configuration_flan_t5.py | 1 - .../models/flan_t5/modeling_flan_t5.py | 6 +- .../models/flan_t5/modeling_flax_flan_t5.py | 6 +- .../models/flan_t5/modeling_tf_flan_t5.py | 6 +- src/openllm/models/gpt_neox/__init__.py | 1 - .../models/gpt_neox/configuration_gpt_neox.py | 1 - .../models/gpt_neox/modeling_gpt_neox.py | 9 +- .../models/gpt_neox/modeling_vllm_gpt_neox.py | 8 +- src/openllm/models/llama/__init__.py | 4 +- .../models/llama/configuration_llama.py | 4 +- src/openllm/models/llama/modeling_llama.py | 13 +- .../models/llama/modeling_vllm_llama.py | 9 +- src/openllm/models/mpt/__init__.py | 1 - src/openllm/models/mpt/modeling_mpt.py | 11 +- src/openllm/models/mpt/modeling_vllm_mpt.py | 24 +- src/openllm/models/opt/__init__.py | 1 - src/openllm/models/opt/configuration_opt.py | 1 - src/openllm/models/opt/modeling_flax_opt.py | 8 +- src/openllm/models/opt/modeling_opt.py | 7 +- 
src/openllm/models/opt/modeling_tf_opt.py | 10 +- src/openllm/models/opt/modeling_vllm_opt.py | 7 +- src/openllm/models/stablelm/__init__.py | 4 +- .../models/stablelm/configuration_stablelm.py | 1 - .../models/stablelm/modeling_stablelm.py | 6 +- .../models/stablelm/modeling_vllm_stablelm.py | 20 +- src/openllm/models/starcoder/__init__.py | 1 - .../models/starcoder/modeling_starcoder.py | 11 +- .../starcoder/modeling_vllm_starcoder.py | 20 +- src/openllm/serialisation/__init__.py | 27 +- src/openllm/serialisation/constants.py | 9 +- src/openllm/serialisation/ggml.py | 24 +- .../serialisation/transformers/__init__.py | 27 +- .../serialisation/transformers/_helpers.py | 19 +- .../serialisation/transformers/weights.py | 9 +- src/openllm/testing.py | 27 +- src/openllm/utils/__init__.py | 37 +-- src/openllm/utils/analytics.py | 25 +- src/openllm/utils/codegen.py | 42 ++-- src/openllm/utils/dantic.py | 28 +-- src/openllm/utils/import_utils.py | 62 ++--- src/openllm/utils/lazy.py | 21 +- src/openllm/utils/representation.py | 69 ++---- tests/__init__.py | 13 - tests/_strategies/__init__.py | 13 - tests/_strategies/_configuration.py | 14 -- tests/client_test.py | 21 -- tests/configuration_test.py | 16 -- tests/conftest.py | 45 +--- tests/models/__init__.py | 13 - tests/models/conftest.py | 40 +-- tests/models/flan_t5_test.py | 14 -- tests/models/opt_test.py | 13 - tests/models_test.py | 14 -- tests/package_test.py | 14 -- tests/strategies_test.py | 14 -- tools/dependencies.py | 1 + tools/update-brew-tap.py | 12 +- tools/update-config-stubs.py | 33 ++- tools/update-init-import.py | 17 -- tools/update-models-import.py | 4 +- typings/attr/__init__.pyi | 29 ++- typings/attr/_cmp.pyi | 18 +- typings/click_option_group/_core.pyi | 9 +- typings/deepmerge/merger.pyi | 8 +- typings/deepmerge/strategy/core.pyi | 8 +- typings/nbformat/_struct.pyi | 9 +- typings/rsmiBindings.pyi | 47 ++-- typings/simple_di/__init__.pyi | 29 +++ typings/simple_di/providers.pyi | 54 ++++ 124 files changed, 1086 insertions(+), 2048 deletions(-) create mode 100644 changelog.d/201.feature.md delete mode 100644 src/openllm/_types.py create mode 100644 src/openllm/_typing_compat.py delete mode 100644 tests/client_test.py delete mode 100755 tools/update-init-import.py create mode 100644 typings/simple_di/__init__.pyi create mode 100644 typings/simple_di/providers.pyi diff --git a/.github/workflows/compile-pypi.yml b/.github/workflows/compile-pypi.yml index b85b07c7..0aa2b985 100644 --- a/.github/workflows/compile-pypi.yml +++ b/.github/workflows/compile-pypi.yml @@ -32,7 +32,7 @@ concurrency: cancel-in-progress: true jobs: pure-wheels-sdist: - name: Building pure wheels and sdist + name: Pure wheels and sdist distribution runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -50,31 +50,26 @@ jobs: path: dist/* if-no-files-found: error mypyc: - name: Building compiled mypyc wheels (${{ matrix.name }}) + name: Compiled mypyc wheels (${{ matrix.name }}) runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: include: - # NOTE: ubuntu x86 - os: ubuntu-latest name: linux-x86_64 - python: 311 # NOTE: darwin amd64 - os: macos-latest name: macos-x86_64 macos_arch: "x86_64" - python: 311 # NOTE: darwin arm64 - os: macos-latest name: macos-arm64 macos_arch: "arm64" - python: 311 # NOTE: darwin universal2 - os: macos-latest name: macos-universal2 macos_arch: "universal2" - python: 311 steps: - uses: actions/checkout@v3 with: @@ -85,9 +80,9 @@ jobs: with: python-version: 3.9 - name: Build wheels via cibuildwheel - uses: 
pypa/cibuildwheel@v2.14.1 + uses: pypa/cibuildwheel@v2.15.0 env: - CIBW_BUILD: cp${{ matrix.python }}-* + CIBW_BEFORE_BUILD_MACOS: "rustup target add aarch64-apple-darwin" CIBW_ARCHS_MACOS: "${{ matrix.macos_arch }}" MYPYPATH: /project/typings - name: Upload wheels as workflow artifacts diff --git a/ADDING_NEW_MODEL.md b/ADDING_NEW_MODEL.md index 63367ba1..c7a68a90 100644 --- a/ADDING_NEW_MODEL.md +++ b/ADDING_NEW_MODEL.md @@ -19,7 +19,7 @@ All the relevant code for incorporating a new model resides within `src/openllm/models/{model_name}/__init__.py` - [ ] Adjust the entrypoints for files at `src/openllm/models/auto/*` - [ ] Modify the main `__init__.py`: `src/openllm/models/__init__.py` -- [ ] Run the following script to update dummy objects: `hatch run update-dummy` +- [ ] Run the following to update stubs: `hatch run check-stubs` For a working example, check out any pre-implemented model. diff --git a/STYLE.md b/STYLE.md index ad1afe95..c1bd89cb 100644 --- a/STYLE.md +++ b/STYLE.md @@ -87,12 +87,11 @@ _If you have any suggestions, feel free to share them on our Discord server!_ def foo(x): return rotate_cv(x) if x > 0 else -x ``` -- imports should be grouped by their types, and each import should be designated - on its own line. +- imports should be grouped by their types: standard library, third-party, and local ```python - import os - import sys + import os, sys + import orjson, bentoml ``` This is partially to make it easier to work with merge conflicts, and easier for IDEs to navigate context definitions. diff --git a/changelog.d/201.feature.md b/changelog.d/201.feature.md new file mode 100644 index 00000000..71c8ac49 --- /dev/null +++ b/changelog.d/201.feature.md @@ -0,0 +1 @@ +Added compiled wheels for all supported Python versions on Linux and macOS diff --git a/hatch.toml b/hatch.toml index fd511617..027abb15 100644 --- a/hatch.toml +++ b/hatch.toml @@ -75,25 +75,21 @@ dependencies = [ ] [envs.default.scripts] changelog = "towncrier build --version main --draft" -check-stubs = [ - "./tools/update-config-stubs.py", - "./tools/update-models-import.py", - "./tools/update-init-import.py", - "update-dummy", -] +check-stubs = ["./tools/update-config-stubs.py", "./tools/update-models-import.py", "update-dummy"] compile = "bash ./compile.sh {args}" -update-dummy = ["- ./tools/update-dummy.py", "./tools/update-dummy.py"] inplace-changelog = "towncrier build --version main --keep" quality = [ "./tools/dependencies.py", "./tools/update-readme.py", "- ./tools/update-brew-tap.py", "check-stubs", - "pre-commit run --all-files", + "- pre-commit run --all-files", ] recompile = ["bash ./clean.sh", "compile"] setup = "pre-commit install" +tool = ["quality", "recompile -nx"] typing = ["- pre-commit run mypy {args:-a}", "- pre-commit run pyright {args:-a}"] +update-dummy = ["- ./tools/update-dummy.py", "./tools/update-dummy.py"] [envs.tests] dependencies = [ # NOTE: interact with docker for container tests.
diff --git a/pyproject.toml b/pyproject.toml index 7ddb6a0b..864723b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "httpx", "click>=8.1.3", "typing_extensions", + "mypy_extensions", "ghapi", "cuda-python;platform_system!=\"Darwin\"", "bitsandbytes<0.42", @@ -124,12 +125,12 @@ vllm = ["vllm", "ray"] [tool.cibuildwheel] build-verbosity = 1 -# So the following enviuronment will be targeted for compiled wheels: -# - Python: CPython 3.8+ only +# So the following environment will be targeted for compiled wheels: +# - Python: CPython 3.8-3.11 # - Architecture (64-bit only): amd64 / x86_64, universal2, and arm64 # - OS: Linux (no musl), and macOS build = "cp3*-*" -skip = ["*-manylinux_i686", "*-musllinux_*", "*-win32", "pp-*"] +skip = ["*-manylinux_i686", "*-musllinux_*", "*-win32", "pp-*", "cp312-*"] [tool.cibuildwheel.environment] HATCH_BUILD_HOOKS_ENABLE = "1" @@ -183,6 +184,9 @@ fail-under = 100 verbose = 2 whitelist-regex = ["test_.*"] +[tool.check-wheel-contents] +toplevel = ["openllm"] + [tool.pytest.ini_options] addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"] python_files = ["test_*.py", "*_test.py"] @@ -194,7 +198,6 @@ extend-exclude = [ "examples", "src/openllm/playground", "src/openllm/__init__.py", - "src/openllm/_types.py", "src/openllm/_version.py", "src/openllm/utils/dummy_*.py", "src/openllm/models/__init__.py", @@ -233,7 +236,9 @@ ignore = [ "PLR0915", "PLR2004", # magic value to use constant "E501", # ignore line length violation + "E401", # ignore multiple line import "E702", + "I001", # unsorted imports "PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs. "D103", # Just missing docstring for magic methods. "D102", @@ -245,13 +250,10 @@ ignore = [ "D105", # magic docstring "E701", # multiple statement on single line ] -line-length = 119 -target-version = "py312" -unfixable = [ - "F401", # Don't touch unused imports, just warn about it. 
- "TCH004", # Don't touch import outside of TYPE_CHECKING block - "RUF100", # unused noqa, just warn about it -] +line-length = 768 +target-version = "py38" +typing-modules = ["openllm._typing_compat"] +unfixable = ["TCH004"] [tool.ruff.flake8-type-checking] exempt-modules = ["typing", "typing_extensions", "."] runtime-evaluated-base-classes = [ @@ -260,7 +262,7 @@ runtime-evaluated-base-classes = [ "openllm._configuration.GenerationConfig", "openllm._configuration.ModelSettings", ] -runtime-evaluated-decorators = ["attrs.define", "attrs.frozen"] +runtime-evaluated-decorators = ["attrs.define", "attrs.frozen", "trait"] [tool.ruff.pydocstyle] convention = "google" [tool.ruff.pycodestyle] @@ -271,7 +273,7 @@ force-single-line = false force-wrap-aliases = true known-first-party = ["openllm", "bentoml"] known-third-party = ["transformers", "click", "huggingface_hub", "torch", "vllm", "auto_gptq"] -lines-after-imports = 1 +lines-after-imports = 0 lines-between-types = 0 no-lines-before = ["future", "standard-library"] relative-imports-order = "closest-to-furthest" @@ -279,12 +281,11 @@ required-imports = ["from __future__ import annotations"] [tool.ruff.flake8-quotes] avoid-escape = false [tool.ruff.extend-per-file-ignores] -"src/openllm/_service.py" = ["I001", "E401"] +"src/openllm/_service.py" = ["E401"] "src/openllm/cli/entrypoint.py" = ["D301"] -"src/openllm/models/**" = ["I001", "E", "D", "F"] -"src/openllm/utils/__init__.py" = ["I001"] -"src/openllm/utils/import_utils.py" = ["PLW0603"] "src/openllm/client/runtimes/*" = ["D107"] +"src/openllm/models/**" = ["E", "D", "F"] +"src/openllm/utils/import_utils.py" = ["PLW0603"] "tests/**/*" = [ "S101", "TID252", @@ -348,6 +349,7 @@ omit = [ "src/openllm/__init__.py", "src/openllm/__main__.py", "src/openllm/utils/dummy_*.py", + "src/openllm/_typing_compat.py", ] source_pkgs = ["openllm"] [tool.coverage.report] @@ -378,6 +380,7 @@ omit = [ "src/openllm/__init__.py", "src/openllm/__main__.py", "src/openllm/utils/dummy_*.py", + "src/openllm/_typing_compat.py", ] precision = 2 show_missing = true @@ -387,16 +390,17 @@ analysis.useLibraryCodeForTypes = true exclude = [ "__pypackages__/*", "src/openllm/playground/", + "src/openllm/models/", "src/openllm/__init__.py", "src/openllm/__main__.py", "src/openllm/utils/dummy_*.py", - "src/openllm/models", + "src/openllm/_typing_compat.py", "tools", "examples", "tests", ] include = ["src/openllm"] -pythonVersion = "3.12" +pythonVersion = "3.8" reportMissingImports = "warning" reportMissingTypeStubs = false reportPrivateUsage = "warning" @@ -408,13 +412,16 @@ reportWildcardImportFromLibrary = "warning" typeCheckingMode = "strict" [tool.mypy] -# TODO: Enable model for strict type checking -exclude = ["src/openllm/playground/", "src/openllm/utils/dummy_*.py", "src/openllm/models"] -local_partial_types = true +exclude = [ + "src/openllm/playground/", + "src/openllm/utils/dummy_*.py", + "src/openllm/models", + "src/openllm/_typing_compat.py", +] modules = ["openllm"] mypy_path = "typings" pretty = true -python_version = "3.11" +python_version = "3.8" show_error_codes = true warn_no_return = false warn_return_any = false @@ -430,6 +437,7 @@ module = [ "optimum.*", "inflection.*", "huggingface_hub.*", + "click_option_group.*", "peft.*", "auto_gptq.*", "vllm.*", @@ -437,13 +445,13 @@ module = [ "httpx.*", "cloudpickle.*", "circus.*", - "grpc_health.*", + "grpc_health.v1.*", "transformers.*", "ghapi.*", ] [[tool.mypy.overrides]] ignore_errors = true -module = ["openllm.models.*", "openllm._types", 
"openllm.playground.*"] +module = ["openllm.models.*", "openllm.playground.*", "openllm._typing_compat"] [tool.hatch.version] fallback-version = "0.0.0" @@ -476,24 +484,12 @@ dependencies = [ "types-protobuf", ] enable-by-default = false -exclude = [ - # no reason to compile model since framework is already written in C++ - "src/openllm/models/**", - "src/openllm/bundle/**", - "src/openllm/cli/**", - "src/openllm/playground/**", - "src/openllm/__main__.py", - "src/openllm/_types.py", - "src/openllm/_llm.py", - "src/openllm/_service.py", - "src/openllm/_configuration.py", - # can't compile serialisation for transformers since it will raise segfault - "src/openllm/serialisation/transformers", - "src/openllm/utils/analytics.py", -] include = [ - "src/openllm/serialisation", + "src/openllm/bundle", + "src/openllm/models/__init__.py", + "src/openllm/models/auto/__init__.py", "src/openllm/utils/__init__.py", + "src/openllm/utils/codegen.py", "src/openllm/__init__.py", "src/openllm/_prompt.py", "src/openllm/_schema.py", @@ -505,20 +501,18 @@ include = [ ] # NOTE: This is consistent with pyproject.toml mypy-args = [ - "--pretty", "--strict", # this is because all transient library doesn't have types "--allow-subclassing-any", - "--check-untyped-defs", "--follow-imports=skip", - "--python-version=3.11", + "--check-untyped-defs", "--ignore-missing-imports", "--no-warn-return-any", "--warn-unreachable", "--no-warn-no-return", "--no-warn-unused-ignores", "--exclude='/src\\/openllm\\/playground\\/**'", - "--exclude='/src\\/openllm\\/_types\\.py$'", + "--exclude='/src\\/openllm\\/_typing_compat\\.py$'", ] -options = { verbose = true, debug_level = "2", opt_level = "3" } +options = { verbose = true, strip_asserts = true, debug_level = "2", opt_level = "3", include_runtime_files = true } require-runtime-dependencies = true diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py index 04dd8b47..aafd18e4 100644 --- a/src/openllm/__init__.py +++ b/src/openllm/__init__.py @@ -9,12 +9,8 @@ deploy, and monitor any LLMs with ease. * Native integration with BentoML and LangChain for custom LLM apps """ from __future__ import annotations -import logging as _logging -import os as _os -import typing as _t -import warnings as _warnings +import logging as _logging, os as _os, typing as _t, warnings as _warnings from pathlib import Path as _Path - from . 
import exceptions as exceptions, utils as utils if utils.DEBUG: @@ -35,7 +31,6 @@ _import_structure: dict[str, list[str]] = { "exceptions": [], "models": [], "client": [], "bundle": [], "playground": [], "testing": [], "utils": ["infer_auto_class"], "serialisation": ["ggml", "transformers"], "cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"], "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "LLMEmbeddings"], "_configuration": ["LLMConfig", "GenerationConfig", "SamplingParams"], "_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"], "_quantisation": ["infer_quantisation_config"], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput", "EmbeddingsOutput", "unmarshal_vllm_outputs", "HfAgentInput"], "models.auto": ["AutoConfig", "CONFIG_MAPPING", "MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"], "models.chatglm": ["ChatGLMConfig"], "models.baichuan": ["BaichuanConfig"], "models.dolly_v2": ["DollyV2Config"], "models.falcon": ["FalconConfig"], "models.flan_t5": ["FlanT5Config"], "models.gpt_neox": ["GPTNeoXConfig"], "models.llama": ["LlamaConfig"], "models.mpt": ["MPTConfig"], "models.opt": ["OPTConfig"], "models.stablelm": ["StableLMConfig"], "models.starcoder": ["StarCoderConfig"] } - COMPILED = _Path(__file__).suffix in (".pyd", ".so") if _t.TYPE_CHECKING: @@ -61,25 +56,38 @@ if _t.TYPE_CHECKING: from .serialisation import ggml as ggml, transformers as transformers from openllm.utils import infer_auto_class as infer_auto_class +try: + if not (utils.is_torch_available() and utils.is_cpm_kernels_available()): raise exceptions.MissingDependencyError +except exceptions.MissingDependencyError: + _import_structure["utils.dummy_pt_objects"] = ["ChatGLM", "Baichuan"] +else: + _import_structure["models.chatglm"].extend(["ChatGLM"]) + _import_structure["models.baichuan"].extend(["Baichuan"]) + if _t.TYPE_CHECKING: + from .models.baichuan import Baichuan as Baichuan + from .models.chatglm import ChatGLM as ChatGLM +try: + if not (utils.is_torch_available() and utils.is_triton_available()): raise exceptions.MissingDependencyError +except exceptions.MissingDependencyError: + if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["MPT"]) + else: _import_structure["utils.dummy_pt_objects"] = ["MPT"] +else: + _import_structure["models.mpt"].extend(["MPT"]) + if _t.TYPE_CHECKING: from .models.mpt import MPT as MPT +try: + if not (utils.is_torch_available() and utils.is_einops_available()): raise exceptions.MissingDependencyError +except exceptions.MissingDependencyError: + if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["Falcon"]) + else: _import_structure["utils.dummy_pt_objects"] = ["Falcon"] +else: + _import_structure["models.falcon"].extend(["Falcon"]) + if _t.TYPE_CHECKING: from .models.falcon import Falcon as Falcon + try: if not utils.is_torch_available(): raise exceptions.MissingDependencyError except exceptions.MissingDependencyError: - _import_structure["utils.dummy_pt_objects"] = utils.dummy_pt_objects.__all__ + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(utils.dummy_pt_objects) if not name.startswith("_") and name not in ("ChatGLM", "Baichuan", "MPT", "Falcon", "annotations")] else: - if utils.is_cpm_kernels_available(): - _import_structure["models.chatglm"].extend(["ChatGLM"]) - 
_import_structure["models.baichuan"].extend(["Baichuan"]) - if _t.TYPE_CHECKING: - from .models.baichuan import Baichuan as Baichuan - from .models.chatglm import ChatGLM as ChatGLM - if utils.is_einops_available(): - _import_structure["models.falcon"].extend(["Falcon"]) - if _t.TYPE_CHECKING: - from .models.falcon import Falcon as Falcon - if utils.is_triton_available(): - _import_structure["models.mpt"].extend(["MPT"]) - if _t.TYPE_CHECKING: - from .models.mpt import MPT as MPT _import_structure["models.flan_t5"].extend(["FlanT5"]) _import_structure["models.dolly_v2"].extend(["DollyV2"]) _import_structure["models.starcoder"].extend(["StarCoder"]) @@ -100,7 +108,7 @@ else: try: if not utils.is_vllm_available(): raise exceptions.MissingDependencyError except exceptions.MissingDependencyError: - _import_structure["utils.dummy_vllm_objects"] = utils.dummy_vllm_objects.__all__ + _import_structure["utils.dummy_vllm_objects"] = [name for name in dir(utils.dummy_vllm_objects) if not name.startswith("_") and name not in ("annotations",)] else: _import_structure["models.baichuan"].extend(["VLLMBaichuan"]) _import_structure["models.llama"].extend(["VLLMLlama"]) @@ -124,7 +132,7 @@ else: try: if not utils.is_flax_available(): raise exceptions.MissingDependencyError except exceptions.MissingDependencyError: - _import_structure["utils.dummy_flax_objects"] = utils.dummy_flax_objects.__all__ + _import_structure["utils.dummy_flax_objects"] = [name for name in dir(utils.dummy_flax_objects) if not name.startswith("_") and name not in ("annotations",)] else: _import_structure["models.flan_t5"].extend(["FlaxFlanT5"]) _import_structure["models.opt"].extend(["FlaxOPT"]) @@ -136,7 +144,7 @@ else: try: if not utils.is_tf_available(): raise exceptions.MissingDependencyError except exceptions.MissingDependencyError: - _import_structure["utils.dummy_tf_objects"] = utils.dummy_tf_objects.__all__ + _import_structure["utils.dummy_tf_objects"] = [name for name in dir(utils.dummy_tf_objects) if not name.startswith("_") and name not in ("annotations",)] else: _import_structure["models.flan_t5"].extend(["TFFlanT5"]) _import_structure["models.opt"].extend(["TFOPT"]) diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index 5652ca09..682cfcdb 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="attr-defined,no-untyped-call,type-var,operator,arg-type,no-redef" """Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``. Highlight feature: Each fields in ``openllm.LLMConfig`` will also automatically generate a environment @@ -33,29 +34,16 @@ dynamically during serve, ahead-of-serve or per requests. Refer to ``openllm.LLMConfig`` docstring for more information. 
""" from __future__ import annotations -import copy -import enum -import logging -import os -import sys -import types -import typing as t - -import attr -import click_option_group as cog -import inflection -import orjson +import copy, enum, logging, os, sys, types, typing as t +import attr, click_option_group as cog, inflection, orjson, openllm from cattr.gen import make_dict_structure_fn, make_dict_unstructure_fn, override from deepmerge.merger import Merger - -import openllm - from ._strategies import LiteralResourceSpec, available_resource_spec, resource_spec +from ._typing_compat import LiteralString, NotRequired, Required, overload, AdapterType, LiteralRuntime from .exceptions import ForbiddenAttributeError from .utils import ( ENV_VARS_TRUE_VALUES, MYPY, - LazyType, ReprMixin, bentoml_cattr, codegen, @@ -63,43 +51,18 @@ from .utils import ( field_env_key, first_not_none, lenient_issubclass, - non_intrusive_setattr, ) from .utils.import_utils import BACKENDS_MAPPING - -# NOTE: We need to do check overload import -# so that it can register -# correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): - from typing import NotRequired, Required, dataclass_transform, overload -else: - from typing_extensions import NotRequired, Required, dataclass_transform, overload - -# NOTE: Using internal API from attr here, since we are actually -# allowing subclass of openllm.LLMConfig to become 'attrs'-ish +# NOTE: Using internal API from attr here, since we are actually allowing subclass of openllm.LLMConfig to become 'attrs'-ish from attr._compat import set_closure_cell from attr._make import _CountingAttr, _make_init, _transform_attrs - -_T = t.TypeVar("_T") - -LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"] +from ._typing_compat import AnyCallable, At, Self, ListStr, DictStrAny if t.TYPE_CHECKING: - import click - import peft - import transformers - import vllm + import click, peft, transformers, vllm from transformers.generation.beam_constraints import Constraint - - from ._types import AnyCallable, At - - DictStrAny = dict[str, t.Any] - ListStr = list[str] else: Constraint = t.Any - ListStr = list - DictStrAny = dict - vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm") transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") peft = openllm.utils.LazyLoader("peft", globals(), "peft") @@ -107,7 +70,7 @@ else: __all__ = ["LLMConfig", "GenerationConfig", "SamplingParams"] logger = logging.getLogger(__name__) -config_merger = Merger([(DictStrAny, "merge")], ["override"], ["override"]) +config_merger = Merger([(dict, "merge")], ["override"], ["override"]) # case insensitive, but rename to conform with type class _PeftEnumMeta(enum.EnumMeta): @@ -141,8 +104,6 @@ class PeftType(str, enum.Enum, metaclass=_PeftEnumMeta): _PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"} -AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"] - _object_setattr = object.__setattr__ def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType: @@ -263,8 +224,7 @@ class GenerationConfig(ReprMixin): if t.TYPE_CHECKING and not MYPY: # stubs this for pyright as mypy already has a attr plugin builtin - def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: - ... + def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: ... 
def __init__(self, *, _internal: bool = False, **attrs: t.Any): if not _internal: raise RuntimeError("GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config") @@ -322,7 +282,7 @@ class SamplingParams(ReprMixin): def to_vllm(self) -> vllm.SamplingParams: return vllm.SamplingParams(max_tokens=self.max_tokens, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, **bentoml_cattr.unstructure(self)) @classmethod - def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> t.Self: + def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> Self: """The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``.""" stop = attrs.pop("stop", None) if stop is not None and isinstance(stop, str): stop = [stop] @@ -480,7 +440,7 @@ bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings) def _setattr_class(attr_name: str, value_var: t.Any) -> str: return f"setattr(cls, '{attr_name}', {value_var})" -def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: t.LiteralString = "openllm") -> t.Callable[..., None]: +def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: LiteralString = "openllm") -> t.Callable[..., None]: """Generate the assignment script with prefix attributes __openllm___.""" args: ListStr = [] globs: DictStrAny = {"cls": cls, "_cached_attribute": attributes, "_cached_getattribute_get": _object_getattribute.__get__} @@ -497,11 +457,6 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance _reserved_namespace = {"__config__", "GenerationConfig", "SamplingParams"} -@dataclass_transform(kw_only_default=True, order_default=True, field_specifiers=(attr.field, dantic.Field)) -def llm_config_transform(cls: type[LLMConfig]) -> type[LLMConfig]: - non_intrusive_setattr(cls, "__dataclass_transform__", {"order_default": True, "kw_only_default": True, "field_specifiers": (attr.field, dantic.Field)}) - return cls - @attr.define(slots=True) class _ConfigAttr: Field = dantic.Field @@ -669,7 +624,7 @@ class _ConfigBuilder: __slots__ = ("_cls", "_cls_dict", "_attr_names", "_attrs", "_model_name", "_base_attr_map", "_base_names", "_has_pre_init", "_has_post_init") - def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr[t.Any]], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True): + def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True): attrs, base_attrs, base_attr_map = _transform_attrs(cls, these, auto_attribs, kw_only, collect_by_mro, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__)) self._cls, self._model_name, self._cls_dict, self._attrs, self._base_names, self._base_attr_map = cls, cls.__openllm_model_name__, dict(cls.__dict__), attrs, {a.name for a in base_attrs}, base_attr_map self._attr_names = tuple(a.name for a in attrs) @@ -742,17 +697,16 @@ class _ConfigBuilder: if not closure_cells: continue # Catch None or the empty list. 
for cell in closure_cells: try: match = cell.cell_contents is self._cls - except ValueError: pass # ValueError: Cell is empty + except ValueError: pass # noqa: PERF203 # ValueError: Cell is empty else: if match: set_closure_cell(cell, cls) + return cls - return llm_config_transform(cls) - - def add_attrs_init(self) -> t.Self: + def add_attrs_init(self) -> Self: self._cls_dict["__attrs_init__"] = codegen.add_method_dunders(self._cls, _make_init(self._cls, self._attrs, self._has_pre_init, self._has_post_init, False, True, False, self._base_attr_map, False, None, True)) return self - def add_repr(self) -> t.Self: + def add_repr(self) -> Self: for key, fn in ReprMixin.__dict__.items(): if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn) self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config", "sampling_config"}) @@ -865,7 +819,7 @@ class LLMConfig(_ConfigAttr): # auto assignment attributes generated from __config__ after create the new slot class. _make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls) - def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: t.LiteralString | None = None) -> type[At]: + def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: LiteralString | None = None) -> type[At]: camel_name = cls.__name__.replace("Config", "") klass = attr.make_class(f"{camel_name}{class_attr}", [], bases=(base,), slots=True, weakref_slot=True, frozen=True, repr=False, init=False, collect_by_mro=True, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__, suffix=suffix_env, globs=globs, default_callback=lambda field_name, field_default: getattr(getattr(cls, class_attr), field_name, field_default) if codegen.has_own_attribute(cls, class_attr) else field_default)) # For pickling to work, the __module__ variable needs to be set to the @@ -882,19 +836,19 @@ class LLMConfig(_ConfigAttr): anns = codegen.get_annotations(cls) # _CountingAttr is the underlying representation of attr.field ca_names = {name for name, attr in cd.items() if isinstance(attr, _CountingAttr)} - these: dict[str, _CountingAttr[t.Any]] = {} + these: dict[str, _CountingAttr] = {} annotated_names: set[str] = set() for attr_name, typ in anns.items(): if codegen.is_class_var(typ): continue annotated_names.add(attr_name) val = cd.get(attr_name, attr.NOTHING) - if not LazyType["_CountingAttr[t.Any]"](_CountingAttr).isinstance(val): + if not isinstance(val, _CountingAttr): if val is attr.NOTHING: val = cls.Field(env=field_env_key(cls.__openllm_model_name__, attr_name)) else: val = cls.Field(default=val, env=field_env_key(cls.__openllm_model_name__, attr_name)) these[attr_name] = val unannotated = ca_names - annotated_names if len(unannotated) > 0: - missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr[t.Any]", cd.get(n)).counter) + missing_annotated = sorted(unannotated, key=lambda n: t.cast("_CountingAttr", cd.get(n)).counter) raise openllm.exceptions.MissingAnnotationAttributeError(f"The following field doesn't have a type annotation: {missing_annotated}") # We need to set the accepted key before generation_config # as generation_config is a special field that users shouldn't pass. @@ -1102,7 +1056,7 @@ class LLMConfig(_ConfigAttr): def __getitem__(self, item: t.Literal["ia3"]) -> dict[str, t.Any]: ... 
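For context on the `field_env_key(cls.__openllm_model_name__, attr_name)` calls in the hunk above: every config field doubles as an environment variable, which is what the module docstring means by fields automatically generating one. The real key derivation lives in `openllm.utils` and is not part of this diff, so the helper below is a hypothetical reconstruction of the idea, with the exact format assumed rather than confirmed:

import os
import typing as t

def field_env_key(model_name: str, key: str, suffix: t.Optional[str] = None) -> str:
    # Assumed format: OPENLLM_<MODEL>_<SUFFIX>_<FIELD>, upper-cased throughout.
    return "_".join(filter(None, ("OPENLLM", model_name.upper(), suffix.upper() if suffix else None, key.upper())))

# e.g. a generation field `max_new_tokens` on flan_t5 would resolve roughly as
value = os.environ.get(field_env_key("flan_t5", "max_new_tokens", suffix="generation"))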
# update-config-stubs.py: stop - def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: + def __getitem__(self, item: LiteralString | t.Any) -> t.Any: """Allowing access to LLMConfig as a dictionary. The order will always evaluate as: __openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__ @@ -1179,13 +1133,13 @@ def model_dump_json(self, **kwargs: t.Any) -> bytes: return orjson.dumps(self.model_dump(**kwargs)) @classmethod - def model_construct_json(cls, json_str: str | bytes) -> t.Self: + def model_construct_json(cls, json_str: str | bytes) -> Self: try: attrs = orjson.loads(json_str) except orjson.JSONDecodeError as err: raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None return bentoml_cattr.structure(attrs, cls) @classmethod - def model_construct_env(cls, **attrs: t.Any) -> t.Self: + def model_construct_env(cls, **attrs: t.Any) -> Self: """A helper that respects configuration values from environment variables.""" attrs = {k: v for k, v in attrs.items() if v is not None} model_config = cls.__openllm_env__.config @@ -1198,7 +1152,7 @@ if "generation_config" in attrs: generation_config = attrs.pop("generation_config") - if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") + if not isinstance(generation_config, dict): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") else: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)} for k in tuple(attrs.keys()): @@ -1258,7 +1212,7 @@ total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(attr.fields_dict(cls.__openllm_sampling_class__)) - if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return f + if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return t.cast("click.Command", f) # We pop out 'generation_config' as it is an attribute that we don't need to expose to CLI.
for name, field in attr.fields_dict(cls).items(): ty = cls.__openllm_hints__.get(name) @@ -1285,13 +1239,13 @@ def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig: Otherwise, we first filter out all keys that belong to LLMConfig, parse those in, then parse the remaining keys into LLMConfig.generation_config """ - if not LazyType(DictStrAny).isinstance(data): raise RuntimeError(f"Expected a dictionary, but got {type(data)}") + if not isinstance(data, dict): raise RuntimeError(f"Expected a dictionary, but got {type(data)}") cls_attrs = {k: v for k, v in data.items() if k in cls.__openllm_accepted_keys__} generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__) if "generation_config" in data: generation_config = data.pop("generation_config") - if not LazyType(DictStrAny).isinstance(generation_config): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") + if not isinstance(generation_config, dict): raise RuntimeError(f"Expected a dictionary, but got {type(generation_config)}") config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields}) else: generation_config = {k: v for k, v in data.items() if k in generation_cls_fields} diff --git a/src/openllm/_generation.py b/src/openllm/_generation.py index ad8328c4..bf00d0b0 100644 --- a/src/openllm/_generation.py +++ b/src/openllm/_generation.py @@ -1,17 +1,9 @@ -"""Generation utilities to be reused throughout.""" from __future__ import annotations -import typing as t - -import transformers - -if t.TYPE_CHECKING: - import torch - - import openllm +import typing as t, transformers +if t.TYPE_CHECKING: import torch, openllm LogitsProcessorList = transformers.LogitsProcessorList StoppingCriteriaList = transformers.StoppingCriteriaList - class StopSequenceCriteria(transformers.StoppingCriteria): def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast): if isinstance(stop_sequences, str): stop_sequences = [stop_sequences] diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 8fad7b83..8fd78d06 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -1,32 +1,13 @@ from __future__ import annotations -import collections -import functools -import inspect -import logging -import os -import re -import sys -import traceback -import types -import typing as t -import uuid +import functools, inspect, logging, os, re, traceback, types, typing as t, uuid from abc import ABC, abstractmethod from pathlib import Path - -import attr -import fs.path -import inflection -import orjson +import attr, fs.path, inflection, orjson, bentoml, openllm from huggingface_hub import hf_hub_download - -import bentoml -import openllm from bentoml._internal.models.model import ModelSignature from ._configuration import ( - AdapterType, FineTuneConfig, - LiteralRuntime, LLMConfig, _object_getattribute, _setattr_class, @@ -57,54 +38,37 @@ from .utils import ( validate_is_path, ) -# NOTE: We need to do this so that overload can register -# correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): - from typing import NotRequired, overload -else: - from typing_extensions import NotRequired, overload +from ._typing_compat import ( + AdaptersMapping, + AdaptersTuple, + AnyCallable, + AdapterType, + LiteralRuntime, + DictStrAny, + ListStr, + LLMEmbeddings, + LLMRunnable, + LLMRunner, + ModelSignatureDict as _ModelSignatureDict, + PeftAdapterOutput,
+ TupleAny, + NotRequired, overload, M, T, LiteralString +) if t.TYPE_CHECKING: - import auto_gptq as autogptq - import peft - import torch - import transformers - import vllm - + import auto_gptq as autogptq, peft, torch, transformers, vllm from ._configuration import PeftType - from ._types import ( - AdaptersMapping, - AdaptersTuple, - AnyCallable, - DictStrAny, - ListStr, - LLMEmbeddings, - LLMRunnable, - LLMRunner, - ModelSignatureDict as _ModelSignatureDict, - PeftAdapterOutput, - TupleAny, - ) from .utils.representation import ReprArgs - - UserDictAny = collections.UserDict[str, t.Any] - ResolvedAdaptersMapping = dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] else: - DictStrAny = dict - TupleAny = tuple - UserDictAny = collections.UserDict - LLMRunnable = bentoml.Runnable - LLMRunner = bentoml.Runner - LLMEmbeddings = dict - autogptq = LazyLoader("autogptq", globals(), "auto_gptq") vllm = LazyLoader("vllm", globals(), "vllm") transformers = LazyLoader("transformers", globals(), "transformers") torch = LazyLoader("torch", globals(), "torch") peft = LazyLoader("peft", globals(), "peft") -logger = logging.getLogger(__name__) +ResolvedAdaptersMapping = t.Dict[AdapterType, t.Dict[t.Union[str, t.Literal["default"]], t.Tuple["peft.PeftConfig", str]]] +logger = logging.getLogger(__name__) class ModelSignatureDict(t.TypedDict, total=False): batchable: bool batch_dim: t.Union[t.Tuple[int, int], int] @@ -150,9 +114,6 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapp _reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"} -M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]") -T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]") - class LLMInterface(ABC, t.Generic[M, T]): """This defines the loose contract for all openllm.LLM implementations.""" @property @@ -257,7 +218,7 @@ class LLMInterface(ABC, t.Generic[M, T]): raise NotImplementedError # NOTE: All fields below are attributes that can be accessed by users. - config_class: type[LLMConfig] + config_class: t.Type[LLMConfig] """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class.""" bettertransformer: bool """Whether to load this LLM with FasterTransformer enabled. The order of loading is: @@ -270,7 +231,7 @@ class LLMInterface(ABC, t.Generic[M, T]): """ device: "torch.device" """The device to be used for this LLM. If the implementation is 'pt', then it will be torch.device, else string.""" - tokenizer_id: t.LiteralString | t.Literal["local"] + tokenizer_id: t.Union[t.Literal["local"], LiteralString] """optional tokenizer_id for loading with vLLM if the model supports vLLM.""" # NOTE: The following will be populated by __init_subclass__, note that these should be immutable. __llm_trust_remote_code__: bool @@ -290,13 +251,13 @@ class LLMInterface(ABC, t.Generic[M, T]): An additional naming for all VLLM backend: VLLMLlama -> `vllm` """ - __llm_model__: M | None + __llm_model__: t.Optional[M] """A reference to the actual model. Instead of access this directly, you should use `model` property instead.""" - __llm_tokenizer__: T | None + __llm_tokenizer__: t.Optional[T] """A reference to the actual tokenizer. 
Instead of access this directly, you should use `tokenizer` property instead.""" - __llm_bentomodel__: bentoml.Model | None + __llm_bentomodel__: t.Optional[bentoml.Model] """A reference to the bentomodel used for this LLM. Instead of access this directly, you should use `_bentomodel` property instead.""" - __llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None + __llm_adapter_map__: t.Optional[ResolvedAdaptersMapping] """A reference to the cached LoRA adapter mapping.""" __llm_supports_embeddings__: bool """A boolean to determine whether the model implements ``LLM.embeddings``.""" @@ -307,11 +268,10 @@ __llm_supports_generate_iterator__: bool """A boolean to determine whether the model implements ``LLM.generate_iterator``.""" if t.TYPE_CHECKING and not MYPY: - def __attrs_init__(self, config: LLMConfig, quantization_config: transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig | None, model_id: str, runtime: t.Literal["ggml", "transformers"], model_decls: TupleAny, model_attrs: DictStrAny, tokenizer_attrs: DictStrAny, tag: bentoml.Tag, adapters_mapping: AdaptersMapping | None, model_version: str | None, quantize_method: t.Literal["int8", "int4", "gptq"] | None, serialisation_format: t.Literal["safetensors", "legacy"], _local: bool, **attrs: t.Any) -> None: + def __attrs_init__(self, config: LLMConfig, quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, autogptq.BaseQuantizeConfig]], model_id: str, runtime: t.Literal["ggml", "transformers"], model_decls: TupleAny, model_attrs: DictStrAny, tokenizer_attrs: DictStrAny, tag: bentoml.Tag, adapters_mapping: t.Optional[AdaptersMapping], model_version: t.Optional[str], quantize_method: t.Optional[t.Literal["int8", "int4", "gptq"]], serialisation_format: t.Literal["safetensors", "legacy"], _local: bool, **attrs: t.Any) -> None: """Generated __attrs_init__ for openllm.LLM.""" _R = t.TypeVar("_R", covariant=True) - class _import_model_wrapper(t.Generic[_R, M, T], t.Protocol): def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R: ... class _load_model_wrapper(t.Generic[M, T], t.Protocol): @@ -333,7 +293,7 @@ def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]) -> t.Ca (model_decls, model_attrs), _ = self.llm_parameters decls = (*model_decls, *decls) attrs = {**model_attrs, **attrs} - return f(self, *decls, trust_remote_code=t.cast(bool, trust_remote_code), **attrs) + return f(self, *decls, trust_remote_code=trust_remote_code, **attrs) return wrapper _DEFAULT_TOKENIZER = "hf-internal-testing/llama-tokenizer" @@ -483,7 +443,7 @@ class LLM(LLMInterface[M, T], ReprMixin): @overload def __getitem__(self, item: t.Literal["bentomodel"]) -> bentoml.Model | None: ... @overload - def __getitem__(self, item: t.Literal["adapter_map"]) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None: ... + def __getitem__(self, item: t.Literal["adapter_map"]) -> ResolvedAdaptersMapping | None: ... @overload def __getitem__(self, item: t.Literal["supports_embeddings"]) -> bool: ... @overload def __getitem__(self, item: t.Literal["supports_generate_one"]) -> bool: ... @overload def __getitem__(self, item: t.Literal["supports_generate_iterator"]) -> bool: ...
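A pattern worth spelling out in this hunk: annotations on `LLMInterface` keep changing from PEP 604 unions such as `M | None` to `t.Optional[M]`. With the build now targeting Python 3.8, this is required wherever annotations get evaluated at runtime (attrs class building, `t.get_type_hints`), because `from __future__ import annotations` only defers evaluation to a string; it does not make `|` valid once something evaluates that string. A minimal reproduction of the failure this rewrite avoids, assuming a Python 3.8 interpreter:

from __future__ import annotations
import typing as t

class A:
    x: int | None  # harmless while it stays an unevaluated string...

t.get_type_hints(A)  # ...but evaluating it on 3.8 raises
                     # TypeError: unsupported operand type(s) for |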
- def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: + def __getitem__(self, item: t.Union[LiteralString, t.Any]) -> t.Any: if item is None: raise TypeError(f"{self} doesn't understand how to index None.") item = inflection.underscore(item) internal_attributes = f"__llm_{item}__" @@ -575,7 +535,7 @@ class LLM(LLMInterface[M, T], ReprMixin): """ cfg_cls = cls.config_class _local = False - _model_id: str = t.cast(str, first_not_none(model_id, os.environ.get(cfg_cls.__openllm_env__["model_id"]), default=cfg_cls.__openllm_default_id__)) + _model_id: str = first_not_none(model_id, os.environ.get(cfg_cls.__openllm_env__["model_id"]), default=cfg_cls.__openllm_default_id__) if validate_is_path(_model_id): _model_id, _local = resolve_filepath(_model_id), True quantize = first_not_none(quantize, t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], os.environ.get(cfg_cls.__openllm_env__["quantize"])), default=None) @@ -889,7 +849,6 @@ class LLM(LLMInterface[M, T], ReprMixin): adapter_mapping = _mapping[adapter_type] self.__llm_model__ = self._wrap_default_peft_model(adapter_mapping, inference_mode=inference_mode) - # now we loop through the rest with add_adapter if len(adapter_mapping) > 0: for adapter_name, (_peft_config, _) in adapter_mapping.items(): @@ -1108,11 +1067,10 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]: def available_adapters(_: LLMRunner[M, T]) -> PeftAdapterOutput: - if not is_peft_available(): return {"success": False, "result": {}, "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'"} - if self.__llm_adapter_map__ is None: return {"success": False, "result": {}, "error_msg": "No adapters available for current running server."} - if not isinstance(self.model, peft.PeftModel): return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"} - return {"success": True, "result": self.model.peft_config, "error_msg": ""} - + if not is_peft_available(): return PeftAdapterOutput(success=False, result={}, error_msg="peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'") + if self.__llm_adapter_map__ is None: return PeftAdapterOutput(success=False, result={}, error_msg="No adapters available for current running server.") + if not isinstance(self.model, peft.PeftModel): return PeftAdapterOutput(success=False, result={}, error_msg="Model is not a PeftModel") + return PeftAdapterOutput(success=True, result=self.model.peft_config, error_msg="") def _wrapped_generate_run(__self: LLMRunner[M, T], prompt: str, **kwargs: t.Any) -> t.Any: """Wrapper for runner.generate.run() to handle the prompt and postprocessing. 
diff --git a/src/openllm/_prompt.py b/src/openllm/_prompt.py index 2b666675..54c4494e 100644 --- a/src/openllm/_prompt.py +++ b/src/openllm/_prompt.py @@ -1,22 +1,17 @@ from __future__ import annotations -import string -import typing as t - +import string, typing as t class PromptFormatter(string.Formatter): """This PromptFormatter is largely based on langchain's implementation.""" def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any: if len(args) > 0: raise ValueError("Positional arguments are not supported") return super().vformat(format_string, args, kwargs) - def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None: extras = set(kwargs).difference(used_args) if extras: raise KeyError(f"Extra params passed: {extras}") - def extract_template_variables(self, template: str) -> t.Sequence[str]: return [field[1] for field in self.parse(template) if field[1] is not None] default_formatter = PromptFormatter() - def process_prompt(prompt: str, template: str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str: # Currently, all default prompt will always have `instruction` key. if not use_prompt_template: return prompt @@ -24,7 +19,5 @@ def process_prompt(prompt: str, template: str | None = None, use_prompt_template template_variables = default_formatter.extract_template_variables(template) prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'") - try: - return template.format(instruction=prompt, **prompt_variables) - except KeyError as e: - raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None + try: return template.format(instruction=prompt, **prompt_variables) + except KeyError as e: raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None diff --git a/src/openllm/_quantisation.py b/src/openllm/_quantisation.py index 93b6f14a..b1123d4a 100644 --- a/src/openllm/_quantisation.py +++ b/src/openllm/_quantisation.py @@ -1,18 +1,9 @@ # mypy: disable-error-code="name-defined" from __future__ import annotations -import logging -import sys -import typing as t - +import logging, sys, typing as t from .utils import LazyLoader, is_autogptq_available, is_bitsandbytes_available, is_transformers_supports_kbit, pkg - -# NOTE: We need to do this so that overload can register -# correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): - from typing import overload -else: - from typing_extensions import overload - +if sys.version_info[:2] >= (3, 11): from typing import overload +else: from typing_extensions import overload if t.TYPE_CHECKING: from ._llm import LLM from ._types import DictStrAny diff --git a/src/openllm/_schema.py b/src/openllm/_schema.py index 2f3c1c02..e5ddad2a 100644 --- a/src/openllm/_schema.py +++ b/src/openllm/_schema.py @@ -1,18 +1,10 @@ """Schema definition for OpenLLM. 
This can be used for client interaction.""" from __future__ import annotations -import functools -import typing as t - -import attr -import inflection - -import openllm - +import functools, typing as t +import attr, inflection, openllm from ._configuration import GenerationConfig, LLMConfig from .utils import bentoml_cattr - -if t.TYPE_CHECKING: - import vllm +if t.TYPE_CHECKING: import vllm @attr.frozen(slots=True) class GenerationInput: @@ -30,7 +22,6 @@ def for_model(cls, model_name: str, **attrs: t.Any) -> type[GenerationInput]: return cls.from_llm_config(openllm.AutoConfig.for_model(model_name, **attrs)) @classmethod def from_llm_config(cls, llm_config: openllm.LLMConfig) -> type[GenerationInput]: return attr.make_class(inflection.camelize(llm_config["model_name"]) + "GenerationInput", attrs={"prompt": attr.field(type=str), "llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)), "adapter_name": attr.field(default=None, type=str)}) - @attr.frozen(slots=True) class GenerationOutput: responses: t.List[t.Any] @@ -43,7 +34,6 @@ if hasattr(self, key): return getattr(self, key) elif key in self.configuration: return self.configuration[key] else: raise KeyError(key) - @attr.frozen(slots=True) class MetadataOutput: model_id: str @@ -53,14 +43,11 @@ configuration: str supports_embeddings: bool supports_hf_agent: bool - @attr.frozen(slots=True) class EmbeddingsOutput: embeddings: t.List[t.List[float]] num_tokens: int - def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> dict[str, t.Any]: return dict(request_id=request_output.request_id, prompt=request_output.prompt, finished=request_output.finished, prompt_token_ids=request_output.prompt_token_ids, outputs=[dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason) for it in request_output.outputs]) - @attr.define class HfAgentInput: inputs: str diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 41d56200..7c80ee61 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -1,13 +1,11 @@ +# mypy: disable-error-code="arg-type,misc" """The service definition for running any LLMService. -Note that the line `model = ...` is a special line and should not be modified. This will be handled by openllm -internally to generate the correct model service when bundling the LLM to a Bento. -This will ensure that 'bentoml serve llm-bento' will work accordingly. - -The generation code lives under utils/codegen.py +Any line with the comment '# openllm: ...' must not be modified, as it is managed internally by OpenLLM.
+Codegen can be found under 'openllm.utils.codegen' """ from __future__ import annotations -import os, typing as t, warnings, orjson, bentoml, openllm +import os, warnings, orjson, bentoml, openllm, typing as t from starlette.applications import Starlette from starlette.responses import JSONResponse from starlette.routing import Route @@ -24,48 +22,31 @@ llm_config = openllm.AutoConfig.for_model(model) runner = openllm.Runner(model, llm_config=llm_config, ensure_available=False, adapter_map=orjson.loads(adapter_map)) svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner]) -@svc.api(route="/v1/generate", input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), # type: ignore[arg-type] # XXX: remove once JSON supports Attrs class - output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)})) +@svc.api(route="/v1/generate", input=bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True)}), output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)})) async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput: qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict) config = qa_inputs.llm_config.model_dump() responses = await runner.generate.async_run(qa_inputs.prompt, **{"adapter_name": qa_inputs.adapter_name, **config}) return openllm.GenerationOutput(responses=responses, configuration=config) -@svc.api( - route="/v1/metadata", input=bentoml.io.Text(), # type: ignore[misc] # XXX: remove once JSON supports Attrs class - output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent}) -) +@svc.api(route="/v1/metadata", input=bentoml.io.Text(), output=bentoml.io.JSON.from_sample({"model_id": runner.llm.model_id, "timeout": 3600, "model_name": llm_config["model_name"], "framework": "pt", "configuration": "", "supports_embeddings": runner.supports_embeddings, "supports_hf_agent": runner.supports_hf_agent})) def metadata_v1(_: str) -> openllm.MetadataOutput: - return openllm.MetadataOutput(model_id=runner.llm.model_id, timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent,) + return openllm.MetadataOutput(timeout=llm_config["timeout"], model_name=llm_config["model_name"], framework=llm_config["env"]["framework_value"], model_id=runner.llm.model_id, configuration=llm_config.model_dump_json().decode(), supports_embeddings=runner.supports_embeddings, supports_hf_agent=runner.supports_hf_agent) if runner.supports_embeddings: - - @svc.api( # type: ignore[arg-type] # XXX: remove once JSON supports Attrs class - input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({ - "embeddings": [ - 0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, 
-0.010825827717781067, -0.015888236463069916, 0.01876218430697918, - 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076 - ], "num_tokens": 20 - }), route="/v1/embeddings" - ) + @svc.api(route="/v1/embeddings", input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample({"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20})) async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput: responses = await runner.embeddings.async_run(phrases) return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"]) if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent(): - async def hf_agent(request: Request) -> Response: json_str = await request.body() - try: - input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput) - except orjson.JSONDecodeError as err: - raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None + try: input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput) + except orjson.JSONDecodeError as err: raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None stop = input_data.parameters.pop("stop", ["\n"]) - try: - return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200) - except NotImplementedError: - return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500) + try: return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200) + except NotImplementedError: return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500) hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])]) svc.mount_asgi_app(hf_app, path="/hf") diff --git a/src/openllm/_strategies.py b/src/openllm/_strategies.py index 43642396..7a71e9d1 100644 --- a/src/openllm/_strategies.py +++ b/src/openllm/_strategies.py @@ -1,38 +1,17 @@ - from __future__ import annotations -import functools -import inspect -import logging -import math -import os -import sys -import types -import typing as t -import warnings - -import psutil - -import bentoml +import functools, inspect, logging, math, os, sys, types, typing as t, warnings, psutil, bentoml from bentoml._internal.resource import get_resource, system_resources from bentoml._internal.runner.strategy import THREAD_ENVS - from .utils import DEBUG, ReprMixin +if sys.version_info[:2] >= (3, 11): from typing import overload +else: from typing_extensions import overload class DynResource(t.Protocol): resource_id: t.ClassVar[str] - @classmethod 
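+  # Editorial aside, a hedged sketch of a conforming dynamic resource (illustrative only,
+  # not from this patch; the resource name mirrors the 'nvidia.com/gpu' resource used
+  # elsewhere in this patch):
+  #   class _NvidiaGpuResource:
+  #     resource_id: t.ClassVar[str] = "nvidia.com/gpu"
+  #     @classmethod
+  #     def from_system(cls) -> t.Sequence[t.Any]: ...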
def from_system(cls) -> t.Sequence[t.Any]: ... -# NOTE: We need to do this so that overload can register -# correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): - from typing import overload -else: - from typing_extensions import overload - logger = logging.getLogger(__name__) - def _strtoul(s: str) -> int: """Return -1 or positive integer sequence string starts with,.""" if not s: return -1 diff --git a/src/openllm/_types.py b/src/openllm/_types.py deleted file mode 100644 index 3a0151c2..00000000 --- a/src/openllm/_types.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Types definition for OpenLLM. - -Note that this module SHOULD NOT BE IMPORTED DURING RUNTIME, as this serve only for typing purposes. -It will raises a RuntimeError if this is imported eagerly. -""" -from __future__ import annotations -import typing as t - -if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime") - -import attr - -import bentoml -from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict - -from ._configuration import ( - AdapterType, - LiteralRuntime as LiteralRuntime, -) - -if t.TYPE_CHECKING: - import peft - - import openllm - from bentoml._internal.runner.runnable import RunnableMethod - from bentoml._internal.runner.runner import RunnerMethod - from bentoml._internal.runner.strategy import Strategy - - from ._llm import ( - M as _M, - T as _T, - ) - from .bundle.oci import LiteralContainerVersionStrategy - from .utils.lazy import VersionInfo - -AnyCallable = t.Callable[..., t.Any] -DictStrAny = dict[str, t.Any] -ListAny = list[t.Any] -ListStr = list[str] -TupleAny = tuple[t.Any, ...] -P = t.ParamSpec("P") -T = t.TypeVar("T") -At = t.TypeVar("At", bound=attr.AttrsInstance) - -class PeftAdapterOutput(t.TypedDict): - success: bool - result: dict[str, peft.PeftConfig] - error_msg: str - -class LLMEmbeddings(t.TypedDict): - embeddings: t.List[t.List[float]] - num_tokens: int - -class AdaptersTuple(TupleAny): - adapter_id: str - name: str | None - config: DictStrAny - -class RefTuple(TupleAny): - git_hash: str - version: VersionInfo - strategy: LiteralContainerVersionStrategy - -AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]] - -class LLMRunnable(bentoml.Runnable, t.Generic[_M, _T]): - SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True - __call__: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]] - set_adapter: RunnableMethod[LLMRunnable[_M, _T], [str], dict[t.Literal["success", "error_msg"], bool | str]] - embeddings: RunnableMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings] - generate: RunnableMethod[LLMRunnable[_M, _T], [str], list[t.Any]] - generate_one: RunnableMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] - generate_iterator: RunnableMethod[LLMRunnable[_M, _T], [str], t.Generator[t.Any, None, None]] - -class LLMRunner(bentoml.Runner, t.Generic[_M, _T]): - __doc__: str - __module__: str - llm_type: str - identifying_params: dict[str, t.Any] - llm: openllm.LLM[_M, _T] - config: openllm.LLMConfig - implementation: LiteralRuntime - supports_embeddings: bool - supports_hf_agent: bool - has_adapters: bool - embeddings: RunnerMethod[LLMRunnable[_M, _T], [list[str]], LLMEmbeddings] - generate: RunnerMethod[LLMRunnable[_M, _T], [str], list[t.Any]] - generate_one: RunnerMethod[LLMRunnable[_M, _T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] - generate_iterator: RunnerMethod[LLMRunnable[_M, 
_T], [str], t.Generator[t.Any, None, None]] - def __init__(self, runnable_class: type[LLMRunnable[_M, _T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,) -> None: ... - def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ... - def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ... - def run(self, prompt: str, **attrs: t.Any) -> t.Any: ... - async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ... - def download_model(self) -> bentoml.Model: ... - @property - def peft_adapters(self) -> PeftAdapterOutput: ... - @property - def __repr_keys__(self) -> set[str]: ... diff --git a/src/openllm/_typing_compat.py b/src/openllm/_typing_compat.py new file mode 100644 index 00000000..8fe8d92e --- /dev/null +++ b/src/openllm/_typing_compat.py @@ -0,0 +1,102 @@ +from __future__ import annotations +import sys, typing as t, bentoml, attr, abc +from bentoml._internal.types import ModelSignatureDict as ModelSignatureDict +if t.TYPE_CHECKING: + import openllm, peft, transformers, auto_gptq as autogptq, vllm + from bentoml._internal.runner.runnable import RunnableMethod + from bentoml._internal.runner.runner import RunnerMethod + from bentoml._internal.runner.strategy import Strategy + + from .bundle.oci import LiteralContainerVersionStrategy + from .utils.lazy import VersionInfo + +M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]") +T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]") + +AnyCallable = t.Callable[..., t.Any] +DictStrAny = t.Dict[str, t.Any] +ListAny = t.List[t.Any] +ListStr = t.List[str] +TupleAny = t.Tuple[t.Any, ...] 
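+# Editorial aside: the aliases above use ``t.Dict``/``t.List``/``t.Tuple`` rather than the
+# PEP 585 builtins, presumably because these assignments are evaluated at runtime and
+# subscripting the builtins (e.g. ``dict[str, t.Any]``) raises TypeError before Python 3.9:
+#   DictStrAny = dict[str, t.Any]    # TypeError on Python 3.8
+#   DictStrAny = t.Dict[str, t.Any]  # works on 3.8 and later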
+At = t.TypeVar("At", bound=attr.AttrsInstance) + +LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"] +AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"] + +if sys.version_info[:2] >= (3,11): + from typing import LiteralString as LiteralString, Self as Self, overload as overload + from typing import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform +else: + from typing_extensions import LiteralString as LiteralString, Self as Self, overload as overload + from typing_extensions import NotRequired as NotRequired, Required as Required, dataclass_transform as dataclass_transform + +if sys.version_info[:2] >= (3,10): + from typing import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate +else: + from typing_extensions import TypeAlias as TypeAlias, ParamSpec as ParamSpec, Concatenate as Concatenate + +if sys.version_info[:2] >= (3,9): + from typing import TypedDict as TypedDict +else: + from typing_extensions import TypedDict as TypedDict + +class PeftAdapterOutput(TypedDict): + success: bool + result: t.Dict[str, peft.PeftConfig] + error_msg: str + +class LLMEmbeddings(t.TypedDict): + embeddings: t.List[t.List[float]] + num_tokens: int + +class AdaptersTuple(TupleAny): + adapter_id: str + name: t.Optional[str] + config: DictStrAny + +AdaptersMapping = t.Dict[AdapterType, t.Tuple[AdaptersTuple, ...]] + +class RefTuple(TupleAny): + git_hash: str + version: VersionInfo + strategy: LiteralContainerVersionStrategy + +class LLMRunnable(bentoml.Runnable, t.Generic[M, T]): + SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu") + SUPPORTS_CPU_MULTI_THREADING = True + __call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]] + set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal["success", "error_msg"], bool | str]] + embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings] + generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]] + generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] + generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[t.Any, None, None]] + +class LLMRunner(bentoml.Runner, t.Generic[M, T]): + __doc__: str + __module__: str + llm_type: str + identifying_params: dict[str, t.Any] + llm: openllm.LLM[M, T] + config: openllm.LLMConfig + implementation: LiteralRuntime + supports_embeddings: bool + supports_hf_agent: bool + has_adapters: bool + embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings] + generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]] + generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]] + generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[t.Any, None, None]] + def __init__(self, runnable_class: type[LLMRunnable[M, T]], *, runnable_init_params: dict[str, t.Any] | None = ..., name: str | None = ..., scheduling_strategy: type[Strategy] = ..., models: list[bentoml.Model] | None = ..., max_batch_size: int | None = ..., max_latency_ms: int | None = ..., method_configs: dict[str, dict[str, int]] | None = ..., embedded: bool = False,) -> None: ... + def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: ... + @abc.abstractmethod + def embed(self, prompt: str | list[str]) -> LLMEmbeddings: ... + def run(self, prompt: str, **attrs: t.Any) -> t.Any: ... 
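+  # Editorial aside, a hedged usage sketch: at runtime these typed attributes resolve through
+  # BentoML's Runner machinery, matching the call sites in _service.py above, e.g.:
+  #   responses = await runner.generate.async_run(prompt, **config)  # -> list[t.Any]
+  #   embedded = await runner.embeddings.async_run([prompt])         # -> LLMEmbeddings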
+ async def async_run(self, prompt: str, **attrs: t.Any) -> t.Any: ... + @abc.abstractmethod + def download_model(self) -> bentoml.Model: ... + @property + @abc.abstractmethod + def peft_adapters(self) -> PeftAdapterOutput: ... + @property + @abc.abstractmethod + def __repr_keys__(self) -> set[str]: ... diff --git a/src/openllm/bundle/__init__.py b/src/openllm/bundle/__init__.py index 51f70ce8..dc276cd4 100644 --- a/src/openllm/bundle/__init__.py +++ b/src/openllm/bundle/__init__.py @@ -3,10 +3,8 @@ These utilities will stay internal, and its API can be changed or updated without backward-compatibility. """ from __future__ import annotations -import sys -import typing as t - -import openllm +import os, typing as t +from openllm.utils import LazyModule _import_structure: dict[str, list[str]] = {"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"], "oci": ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]} @@ -29,5 +27,8 @@ if t.TYPE_CHECKING: get_base_container_tag as get_base_container_tag, supported_registries as supported_registries, ) -else: - sys.modules[__name__] = openllm.utils.LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + +__lazy=LazyModule(__name__, os.path.abspath("__file__"), _import_structure) +__all__=__lazy.__all__ +__dir__=__lazy.__dir__ +__getattr__=__lazy.__getattr__ diff --git a/src/openllm/bundle/_package.py b/src/openllm/bundle/_package.py index cb316bef..71deb419 100644 --- a/src/openllm/bundle/_package.py +++ b/src/openllm/bundle/_package.py @@ -1,30 +1,18 @@ +# mypy: disable-error-code="misc" from __future__ import annotations -import importlib.metadata -import inspect -import logging -import os -import typing as t +import importlib.metadata, inspect, logging, os, typing as t from pathlib import Path - -import fs -import fs.copy -import fs.errors -import orjson +import fs, fs.copy, fs.errors, orjson, bentoml, openllm from simple_di import Provide, inject - -import bentoml -import openllm from bentoml._internal.bento.build_config import BentoBuildConfig, DockerOptions, ModelSpec, PythonOptions from bentoml._internal.configuration.containers import BentoMLContainer - from . import oci if t.TYPE_CHECKING: from fs.base import FS - + from openllm._typing_compat import LiteralString from bentoml._internal.bento import BentoStore from bentoml._internal.models.model import ModelStore - from .oci import LiteralContainerRegistry, LiteralContainerVersionStrategy logger = logging.getLogger(__name__) @@ -39,7 +27,7 @@ def build_editable(path: str) -> str | None: from build.env import IsolatedEnvBuilder module_location = openllm.utils.pkg.source_locations("openllm") if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.") - pyproject_path = Path(module_location).parent.parent / "pyproject.toml" + pyproject_path = Path(module_location).parent.parent/"pyproject.toml" if os.path.isfile(pyproject_path.__fspath__()): logger.info("OpenLLM is installed in editable mode. 
Generating built wheels...") with IsolatedEnvBuilder() as env: @@ -80,31 +68,26 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d _tf_version = importlib.metadata.version(candidate) packages.extend([f"tensorflow>={_tf_version}"]) break - except importlib.metadata.PackageNotFoundError: - pass + except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distribution. else: if not openllm.utils.is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.") packages.extend([f'torch>={importlib.metadata.version("torch")}']) - wheels: list[str] = [] built_wheels = build_editable(llm_fs.getsyspath("/")) if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}")) return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"]) -def construct_docker_options( - llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry, - container_version_strategy: LiteralContainerVersionStrategy, -) -> DockerOptions: +def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: int | float, quantize: LiteralString | None, bettertransformer: bool | None, adapter_map: dict[str, str | None] | None, dockerfile_template: str | None, runtime: t.Literal["ggml", "transformers"], serialisation_format: t.Literal["safetensors", "legacy"], container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions: _bentoml_config_options = os.environ.pop("BENTOML_CONFIG_OPTIONS", "") - _bentoml_config_options_opts = [ - "tracing.sample_rate=1.0", f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'api_server.traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}', - ] + _bentoml_config_options_opts = ["tracing.sample_rate=1.0", f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'api_server.traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".traffic.timeout={llm.config["timeout"]}', f'runners."llm-{llm.config["start_name"]}-runner".workers_per_resource={workers_per_resource}'] _bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts) env: openllm.utils.EnvVarMixin = llm.config["env"] if env["framework_value"] == "vllm": serialisation_format = "legacy" env_dict = { - env.framework: env["framework_value"], env.config: f"'{llm.config.model_dump_json().decode()}'", "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", - env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}" + env.framework: env["framework_value"], env.config: 
f"'{llm.config.model_dump_json().decode()}'", + env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", + "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_SERIALIZATION": serialisation_format, + "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(True), "BENTOML_QUIET": str(False), "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", } if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") @@ -117,10 +100,9 @@ def construct_docker_options( return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template) @inject -def create_bento( - bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: t.LiteralString | None, bettertransformer: bool | None, device: tuple[str, ...] | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] | None = None, runtime: t.Literal[ - "ggml", "transformers"] = "transformers", serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release", _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store], -) -> bentoml.Bento: +def create_bento(bento_tag: bentoml.Tag, llm_fs: FS, llm: openllm.LLM[t.Any, t.Any], workers_per_resource: str | int | float, quantize: LiteralString | None, bettertransformer: bool | None, dockerfile_template: str | None, adapter_map: dict[str, str | None] | None = None, extra_dependencies: tuple[str, ...] 
| None = None,
+                 runtime: t.Literal["ggml", "transformers"] = "transformers", serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors", container_registry: LiteralContainerRegistry = "ecr", container_version_strategy: LiteralContainerVersionStrategy = "release",
+                 _bento_store: BentoStore = Provide[BentoMLContainer.bento_store], _model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
   framework_envvar = llm.config["env"]["framework_value"]
   labels = dict(llm.identifying_params)
   labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
@@ -129,13 +111,9 @@ def create_bento(
   if workers_per_resource == "round_robin": workers_per_resource = 1.0
   elif workers_per_resource == "conserved": workers_per_resource = 1.0 if openllm.utils.device_count() == 0 else float(1 / openllm.utils.device_count())
   else:
-    try:
-      workers_per_resource = float(workers_per_resource)
-    except ValueError:
-      raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
-  elif isinstance(workers_per_resource, int):
-    workers_per_resource = float(workers_per_resource)
-
+    try: workers_per_resource = float(workers_per_resource)
+    except ValueError: raise ValueError("'workers_per_resource' only accepts ['round_robin', 'conserved'] as possible strategies.") from None
+  elif isinstance(workers_per_resource, int): workers_per_resource = float(workers_per_resource)
   logger.info("Building Bento for '%s'", llm.config["start_name"])
   # add service.py definition to this temporary folder
   openllm.utils.codegen.write_service(llm, adapter_map, llm_fs)
diff --git a/src/openllm/bundle/oci/__init__.py b/src/openllm/bundle/oci/__init__.py
index 30489fc1..fda395f5 100644
--- a/src/openllm/bundle/oci/__init__.py
+++ b/src/openllm/bundle/oci/__init__.py
@@ -1,31 +1,21 @@
+# mypy: disable-error-code="misc"
 """OCI-related utilities for OpenLLM.
This module is considered internal, and its APIs are subject to change."""
from __future__ import annotations
-import functools
-import importlib
-import logging
-import pathlib
-import shutil
-import subprocess
-import typing as t
+import functools, importlib, logging, os, pathlib, shutil, subprocess, typing as t
 from datetime import datetime, timedelta, timezone
-
-import attr
-import orjson
-
-import bentoml
-import openllm
+import attr, orjson, bentoml, openllm
+from openllm.utils.lazy import VersionInfo
 if t.TYPE_CHECKING:
   from ghapi import all
+  from openllm._typing_compat import RefTuple
-  from openllm._types import DictStrAny, RefTuple
-else:
-  all = openllm.utils.LazyLoader("all", globals(), "ghapi.all")
+all = openllm.utils.LazyLoader("all", globals(), "ghapi.all") # noqa: F811

 logger = logging.getLogger(__name__)
 _BUILDER = bentoml.container.get_backend("buildx")
-ROOT_DIR = pathlib.Path(__file__).parent.parent.parent
+ROOT_DIR = pathlib.Path(os.path.abspath("__file__")).parent.parent.parent
 # TODO: support quay
 LiteralContainerRegistry = t.Literal["docker", "gh", "ecr"]
@@ -45,14 +35,10 @@ _module_location = openllm.utils.pkg.source_locations("openllm")
 @functools.lru_cache
 @openllm.utils.apply(str.lower)
-def get_base_container_name(reg: LiteralContainerRegistry) -> str:
-  return _CONTAINER_REGISTRY[reg]
+def get_base_container_name(reg: LiteralContainerRegistry) -> str: return _CONTAINER_REGISTRY[reg]

-def _convert_version_from_string(s: str) -> openllm.utils.VersionInfo:
-  return openllm.utils.VersionInfo.from_version_string(s)
-
-def _commit_time_range(r: int = 5) -> str:
-  return (datetime.now(timezone.utc) - timedelta(days=r)).strftime("%Y-%m-%dT%H:%M:%SZ")
+def _convert_version_from_string(s: str) -> VersionInfo: return VersionInfo.from_version_string(s)
+def _commit_time_range(r: int = 5) -> str: return (datetime.now(timezone.utc) - timedelta(days=r)).strftime("%Y-%m-%dT%H:%M:%SZ")

 class VersionNotSupported(openllm.exceptions.OpenLLMException):
   """Raised when the stable release is too old to include an OpenLLM base container."""
@@ -66,54 +52,42 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
   docker_bin = shutil.which("docker")
   if docker_bin is None:
     logger.warning("To get the correct available nightly container, make sure Docker is available. Falling back to the previous behaviour for determining the nightly hash (the container might not exist due to the lack of a GPU machine at the time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available images.)")
-    commits = t.cast("list[DictStrAny]", cls._ghapi.repos.list_commits(since=_commit_time_range()))
+    commits = t.cast("list[dict[str, t.Any]]", cls._ghapi.repos.list_commits(since=_commit_time_range()))
     return next(f'sha-{it["sha"][:7]}' for it in commits if "[skip ci]" not in it["commit"]["message"])
   # now is the correct behaviour
   return orjson.loads(subprocess.check_output([docker_bin, "run", "--rm", "-it", "quay.io/skopeo/stable:latest", "list-tags", "docker://ghcr.io/bentoml/openllm"]).decode().strip())["Tags"][-2]

 @attr.attrs(eq=False, order=False, slots=True, frozen=True)
 class RefResolver:
-  """TODO: Support offline mode.
-
-  Maybe we need to save git hash when building the Bento.
-  """
   git_hash: str = attr.field()
   version: openllm.utils.VersionInfo = attr.field(converter=_convert_version_from_string)
   strategy: LiteralContainerVersionStrategy = attr.field()
-  _ghapi: all.GhApi = all.GhApi(owner=_OWNER, repo=_REPO) # TODO: support offline mode
-
+  _ghapi: t.ClassVar[all.GhApi] = all.GhApi(owner=_OWNER, repo=_REPO)
   @classmethod
-  def _nightly_ref(cls) -> RefTuple:
-    return _RefTuple((nightly_resolver(cls), "refs/heads/main", "nightly"))
-
+  def _nightly_ref(cls) -> RefTuple: return _RefTuple((nightly_resolver(cls), "refs/heads/main", "nightly"))
   @classmethod
   def _release_ref(cls, version_str: str | None = None) -> RefTuple:
     _use_base_strategy = version_str is None
     if version_str is None:
       # NOTE: This strategy will only support openllm>0.2.12
-      meta: DictStrAny = cls._ghapi.repos.get_latest_release()
+      meta: dict[str, t.Any] = cls._ghapi.repos.get_latest_release()
       version_str = meta["name"].lstrip("v")
       version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")["object"]["sha"], version_str)
-    else:
-      version = ("", version_str)
+    else: version = ("", version_str)
     if openllm.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12): raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrading to 'openllm>=0.2.12'")
     return _RefTuple((*version, "release" if _use_base_strategy else "custom"))
-
   @classmethod
   @functools.lru_cache(maxsize=64)
   def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | str | None = None) -> RefResolver:
-    if strategy_or_version is None or strategy_or_version == "release":
-      logger.debug("Using default strategy 'release' for resolving base image version.")
-      return cls(*cls._release_ref())
-    elif strategy_or_version == "latest":
-      return cls("latest", "0.0.0", "latest")
+    # using default strategy
+    if strategy_or_version is None or strategy_or_version == "release": return cls(*cls._release_ref())
+    elif strategy_or_version == "latest": return cls("latest", "0.0.0", "latest")
     elif strategy_or_version == "nightly":
       _ref = cls._nightly_ref()
       return cls(_ref[0], "0.0.0", _ref[-1])
     else:
       logger.warning("Using custom %s. Make sure that it is at least 0.2.12 for base container support.", strategy_or_version)
       return cls(*cls._release_ref(version_str=strategy_or_version))
-
   @property
   def tag(self) -> str:
     # NOTE: the latest tag can also be nightly, but using it is discouraged; for nightly, use the sha- tags instead
@@ -122,33 +96,24 @@ class RefResolver:
     else: return repr(self.version)

 @functools.lru_cache(maxsize=256)
-def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
-  return RefResolver.from_strategy(strategy).tag
-
+def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str: return RefResolver.from_strategy(strategy).tag
 def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None, version_strategy: LiteralContainerVersionStrategy = "release", push: bool = False, machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
-  """This is a utility function for building base container for OpenLLM. It will build the base container for all registries if ``None`` is passed.
-
-  Note that this is useful for debugging or for any users who wish to integrate vertically with OpenLLM. For most users, you should be able to get the image either from GitHub Container Registry or our public ECR registry.
-  """
   try:
     if not _BUILDER.health(): raise openllm.exceptions.Error
-  except (openllm.exceptions.Error, subprocess.CalledProcessError):
-    raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
+  except (openllm.exceptions.Error, subprocess.CalledProcessError): raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for installation instructions.") from None
   if openllm.utils.device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
   if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernels in the container.")
   if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
   pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
   if not pyproject_path.exists(): raise ValueError("This utility can only be run within the OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
-  if t.TYPE_CHECKING: tags: dict[str | LiteralContainerRegistry, str]
-  if not registries: tags = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
+  if not registries: tags: dict[str | LiteralContainerRegistry, str] = {alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()} # default to all registries with latest tag strategy
   else:
     registries = [registries] if isinstance(registries, str) else list(registries)
     tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
   try:
     outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(), context_path=pyproject_path.parent.__fspath__(), tag=tuple(tags.values()), push=push, progress="plain" if openllm.utils.get_debug_mode() else "auto", quiet=machine)
     if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
-  except Exception as err:
-    raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
+  except Exception as err: raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
   return tags

 if t.TYPE_CHECKING:
@@ -156,10 +121,7 @@ if t.TYPE_CHECKING:
   supported_registries: list[str]

 __all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]
-
-def __dir__() -> list[str]:
-  return sorted(__all__)
-
+def __dir__() -> list[str]: return sorted(__all__)
 def __getattr__(name: str) -> t.Any:
   if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
   elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
diff --git a/src/openllm/cli/_factory.py b/src/openllm/cli/_factory.py
index 8b075d62..21e1d8a9 100644
--- a/src/openllm/cli/_factory.py
+++ b/src/openllm/cli/_factory.py
@@ -1,37 +1,23 @@
 from __future__ import annotations
-import functools
-import importlib.util
-import os
-import
typing as t - -import click -import click_option_group as cog -import inflection -import orjson +import functools, importlib.util, os, typing as t +import click, click_option_group as cog, inflection, orjson, bentoml, openllm from bentoml_cli.utils import BentoMLCommandGroup from click.shell_completion import CompletionItem - -import bentoml -import openllm from bentoml._internal.configuration.containers import BentoMLContainer - +from openllm._typing_compat import LiteralString, DictStrAny, ParamSpec, Concatenate from . import termui if t.TYPE_CHECKING: import subprocess + from openllm._configuration import LLMConfig - from .._configuration import LLMConfig - from .._types import DictStrAny, P - TupleStr = tuple[str, ...] -else: - TupleStr = tuple - +P = ParamSpec("P") LiteralOutput = t.Literal["json", "pretty", "porcelain"] _AnyCallable = t.Callable[..., t.Any] FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command]) -def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny: +def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, environ: DictStrAny,) -> DictStrAny: # TODO: Support amd.com/gpu on k8s _bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "") _bentoml_config_options_opts = ["tracing.sample_rate=1.0", f"api_server.traffic.timeout={server_timeout}", f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}', f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}'] @@ -44,17 +30,15 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res _adapter_mapping_key = "adapter_map" -def _id_callback(ctx: click.Context, _: click.Parameter, value: tuple[str, ...] | None) -> None: +def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] 
| None) -> None: if not value: return None if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {} for v in value: adapter_id, *adapter_name = v.rsplit(":", maxsplit=1) # try to resolve the full path if users pass in relative, # currently only support one level of resolve path with current directory - try: - adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd()) - except FileNotFoundError: - pass + try: adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd()) + except FileNotFoundError: pass ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None return None @@ -104,7 +88,7 @@ Available official model_id(s): [default: {llm_config['default_id']}] @start_decorator(llm_config, serve_grpc=_serve_grpc) @click.pass_context def start_cmd( - ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | t.LiteralString, device: tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool, + ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None, workers_per_resource: t.Literal["conserved", "round_robin"] | LiteralString, device: t.Tuple[str, ...], quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool, serialisation_format: t.Literal["safetensors", "legacy"], adapter_id: str | None, return_process: bool, **attrs: t.Any, ) -> LLMConfig | subprocess.Popen[bytes]: fast = str(fast).upper() in openllm.utils.ENV_VARS_TRUE_VALUES @@ -152,7 +136,7 @@ Available official model_id(s): [default: {llm_config['default_id']}] llm = openllm.utils.infer_auto_class(env["framework_value"]).for_model(model, model_id=start_env[env.model_id], model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format) start_env.update({env.config: llm.config.model_dump_json().decode()}) - server = bentoml.GrpcServer("_service.py:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service.py:svc", **server_attrs) + server = bentoml.GrpcServer("_service:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service:svc", **server_attrs) openllm.utils.analytics.track_start_init(llm.config) def next_step(model_name: str, adapter_map: DictStrAny | None) -> None: @@ -192,7 +176,7 @@ def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, * return noop -def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None: +def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None: if adapter_map and not openllm.utils.is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'") if quantize and llm_config.default_implementation() == "vllm": ctx.fail(f"Quantization is not yet supported with vLLM. 
Set '{llm_config['env']['framework']}=\"pt\"' to run with quantization.") requirements = llm_config["requirements"] @@ -201,8 +185,8 @@ def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: t.Li if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow") def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]: - return lambda fn: openllm.utils.compose( - *[ + def wrapper(fn: FC) -> t.Callable[[FC], FC]: + composed = openllm.utils.compose( llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args, cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."), model_id_option(factory=cog.optgroup, model_env=llm_config["env"]), @@ -246,13 +230,14 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab ), cog.optgroup.option("--adapter-id", default=None, help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'", multiple=True, callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"), click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True), - ] - )(fn) + ) + return composed(fn) + return wrapper -def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> TupleStr | None: +def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None: if value is None: return value - if not openllm.utils.LazyType(TupleStr).isinstance(value): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})") - el: TupleStr = tuple(i for k in value for i in k) + if not isinstance(value, tuple): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})") + el: t.Tuple[str, ...] 
= tuple(i for k in value for i in k) # NOTE: --device all is a special case if len(el) == 1 and el[0] == "all": return tuple(map(str, openllm.utils.available_devices())) return el @@ -269,7 +254,7 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig] command = "serve" if not serve_grpc else "serve-grpc" group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",) - def decorator(f: t.Callable[t.Concatenate[int, str | None, P], LLMConfig]) -> t.Callable[[FC], FC]: + def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]: serve_command = cli.commands[command] # The first variable is the argument bento # The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS @@ -285,7 +270,6 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig] param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts")) f = cog.optgroup.option(*param_decls, **attrs)(f) return group(f) - return decorator _http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True) @@ -299,12 +283,10 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | factory = attrs.pop("factory", click) factory_attr = attrs.pop("attr", "option") if factory_attr != "argument": attrs.setdefault("help", "General option for OpenLLM CLI.") - def decorator(f: FC | None) -> FC: callback = getattr(factory, factory_attr, None) if callback is None: raise ValueError(f"Factory {factory} has no attribute {factory_attr}.") return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs)) - return decorator cli_option = functools.partial(_click_factory_type, attr="option") @@ -312,12 +294,8 @@ cli_argument = functools.partial(_click_factory_type, attr="argument") def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = "pretty", **attrs: t.Any) -> t.Callable[[FC], FC]: output = ["json", "pretty", "porcelain"] - - def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: - return [CompletionItem(it) for it in output] - + def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]: return [CompletionItem(it) for it in output] return cli_option("-o", "--output", "output", type=click.Choice(output), default=default_value, help="Showing output type.", show_default=True, envvar="OPENLLM_OUTPUT", show_envvar=True, shell_complete=complete_output_var, **attrs)(f) - def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( "--fast/--no-fast", show_default=True, default=False, envvar="OPENLLM_USE_LOCAL_LATEST", show_envvar=True, help="""Whether to skip checking if models is already in store. @@ -325,18 +303,10 @@ def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC This is useful if you already downloaded or setup the model beforehand. 
""", **attrs )(f) - -def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: - return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f) - -def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: - return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f) - -def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: - return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f) - -def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]: - return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required)(f) +def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f) +def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-id", type=click.STRING, default=None, envvar=model_env.model_id if model_env is not None else None, show_envvar=model_env is not None, help="Optional model_id name or path for (fine-tune) weight.", **attrs)(f) +def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. 
It will be inferred automatically from model-id.", **attrs)(f)
+def model_name_argument(f: _AnyCallable | None = None, required: bool = True) -> t.Callable[[FC], FC]: return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required)(f)

 def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
   return cli_option(
@@ -423,10 +393,8 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
     value = inflection.underscore(value)
     if value in _wpr_strategies: return value
     else:
-      try:
-        float(value) # type: ignore[arg-type]
-      except ValueError:
-        raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
+      try: float(value) # type: ignore[arg-type]
+      except ValueError: raise click.BadParameter(f"'workers_per_resource' only accepts '{_wpr_strategies}' as possible strategies; otherwise pass in a float.", ctx, param) from None
       else: return value
diff --git a/src/openllm/cli/_sdk.py b/src/openllm/cli/_sdk.py
index f41de7ae..27ab8069 100644
--- a/src/openllm/cli/_sdk.py
+++ b/src/openllm/cli/_sdk.py
@@ -1,34 +1,21 @@
-
 from __future__ import annotations
-import itertools
-import logging
-import os
-import re
-import subprocess
-import sys
-import typing as t
-
+import itertools, logging, os, re, subprocess, sys, typing as t
+import bentoml, openllm
 from simple_di import Provide, inject
-
-import bentoml
-import openllm
 from bentoml._internal.configuration.containers import BentoMLContainer
 from openllm.exceptions import OpenLLMException
-
 from . import termui
 from ._factory import start_command_factory
 if t.TYPE_CHECKING:
+  from openllm._typing_compat import LiteralString, LiteralRuntime
   from bentoml._internal.bento import BentoStore
-  from openllm._configuration import LiteralRuntime, LLMConfig
+  from openllm._configuration import LLMConfig
   from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy

 logger = logging.getLogger(__name__)

-def _start(
-    model_name: str, /, *, model_id: str | None = None, timeout: int = 30, workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, device: tuple[str, ...] | t.Literal["all"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers",
-    fast: bool = False, adapter_map: dict[t.LiteralString, str | None] | None = None, framework: LiteralRuntime | None = None, additional_args: list[str] | None = None, _serve_grpc: bool = False, __test__: bool = False, **_: t.Any,
-) -> LLMConfig | subprocess.Popen[bytes]:
+def _start(model_name: str, /, *, model_id: str | None = None, timeout: int = 30, workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, device: tuple[str, ...] | t.Literal["all"] | None = None, quantize: t.Literal["int8", "int4", "gptq"] | None = None, bettertransformer: bool | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", fast: bool = False, adapter_map: dict[LiteralString, str | None] | None = None, framework: LiteralRuntime | None = None, additional_args: list[str] | None = None, _serve_grpc: bool = False, __test__: bool = False, **_: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
   """Python API to start a LLM server.
   These provide a one-to-one mapping to CLI arguments.

   For all additional arguments, pass them as strings to ``additional_args``. For example, if you want to
@@ -169,9 +156,7 @@ def _build(model_name: str, /, *, model_id: str | None = None, model_version: st
     if matched is None: raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
     return bentoml.get(matched.group(1), _bento_store=bento_store)

-def _import_model(
-    model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", implementation: LiteralRuntime = "pt", quantize: t.Literal["int8", "int4", "gptq"] | None = None, serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", additional_args: t.Sequence[str] | None = None,
-) -> bentoml.Model:
+def _import_model(model_name: str, /, *, model_id: str | None = None, model_version: str | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", implementation: LiteralRuntime = "pt", quantize: t.Literal["int8", "int4", "gptq"] | None = None, serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors", additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
   """Import a LLM into local store.

   > [!NOTE]
diff --git a/src/openllm/cli/entrypoint.py b/src/openllm/cli/entrypoint.py
index b9c0b0a3..5091cd47 100644
--- a/src/openllm/cli/entrypoint.py
+++ b/src/openllm/cli/entrypoint.py
@@ -20,36 +20,12 @@ bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct'
 ```
 """
 from __future__ import annotations
-import functools
-import http.client
-import inspect
-import itertools
-import logging
-import os
-import platform
-import re
-import subprocess
-import sys
-import time
-import traceback
-import typing as t
-
-import attr
-import click
-import click_option_group as cog
-import fs
-import fs.copy
-import fs.errors
-import inflection
-import orjson
+import functools, http.client, inspect, itertools, logging, os, platform, re, subprocess, sys, time, traceback, typing as t
+import attr, click, click_option_group as cog, fs, fs.copy, fs.errors, inflection, orjson, bentoml, openllm
 from bentoml_cli.utils import BentoMLCommandGroup, opt_callback
 from simple_di import Provide, inject
-
-import bentoml
-import openllm
 from bentoml._internal.configuration.containers import BentoMLContainer
 from bentoml._internal.models.model import ModelStore
-
 from . import termui
 from ._factory import (
   FC,
@@ -69,9 +45,9 @@ from ._factory import (
   start_command_factory,
   workers_per_resource_option,
 )
-from ..
import bundle, serialisation -from ..exceptions import OpenLLMException -from ..models.auto import ( +from openllm import bundle, serialisation +from openllm.exceptions import OpenLLMException +from openllm.models.auto import ( CONFIG_MAPPING, MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES, @@ -80,7 +56,8 @@ from ..models.auto import ( AutoConfig, AutoLLM, ) -from ..utils import ( +from openllm._typing_compat import DictStrAny, ParamSpec, Concatenate, LiteralString, Self, LiteralRuntime +from openllm.utils import ( DEBUG, DEBUG_ENV_VAR, OPTIONAL_DEPENDENCIES, @@ -105,19 +82,15 @@ from ..utils import ( if t.TYPE_CHECKING: import torch - from bentoml._internal.bento import BentoStore from bentoml._internal.container import DefaultBuilder from openllm.client import BaseClient + from openllm._schema import EmbeddingsOutput + from openllm.bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy +else: torch = LazyLoader("torch", globals(), "torch") - from .._schema import EmbeddingsOutput - from .._types import DictStrAny, LiteralRuntime, P - from ..bundle.oci import LiteralContainerRegistry, LiteralContainerVersionStrategy -else: - torch, jupytext, nbformat = LazyLoader("torch", globals(), "torch"), LazyLoader("jupytext", globals(), "jupytext"), LazyLoader("nbformat", globals(), "nbformat") - +P = ParamSpec("P") logger = logging.getLogger(__name__) - OPENLLM_FIGLET = """\ ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ @@ -132,9 +105,7 @@ ServeCommand = t.Literal["serve", "serve-grpc"] @attr.define class GlobalOptions: cloud_context: str | None = attr.field(default=None) - - def with_options(self, **attrs: t.Any) -> t.Self: - return attr.evolve(self, **attrs) + def with_options(self, **attrs: t.Any) -> Self: return attr.evolve(self, **attrs) GrpType = t.TypeVar("GrpType", bound=click.Group) @@ -173,7 +144,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): return wrapper @staticmethod - def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[t.Concatenate[bool, P], t.Any]: + def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[Concatenate[bool, P], t.Any]: command_name = attrs.get("name", func.__name__) @functools.wraps(func) @@ -197,7 +168,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): event.return_code = 2 if isinstance(e, KeyboardInterrupt) else 1 analytics.track(event) raise - return t.cast("t.Callable[t.Concatenate[bool, P], t.Any]", wrapper) + return t.cast(t.Callable[Concatenate[bool, P], t.Any], wrapper) @staticmethod def exception_handling(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[P, t.Any]: @@ -266,20 +237,17 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): for subcommand in self.list_commands(ctx): cmd = self.get_command(ctx, subcommand) if cmd is None or cmd.hidden: continue - if subcommand in _cached_extensions: - extensions.append((subcommand, cmd)) - else: - commands.append((subcommand, cmd)) + if subcommand in _cached_extensions: extensions.append((subcommand, cmd)) + else: commands.append((subcommand, cmd)) # allow for 3 times the default spacing if len(commands): limit = formatter.width - 6 - max(len(cmd[0]) for cmd in commands) - rows = [] + rows: list[tuple[str, str]]= [] for subcommand, cmd in commands: help = cmd.get_short_help_str(limit) rows.append((subcommand, help)) if rows: - with formatter.section(_("Commands")): - formatter.write_dl(rows) + with 
formatter.section(_("Commands")): formatter.write_dl(rows) if len(extensions): limit = formatter.width - 6 - max(len(cmd[0]) for cmd in extensions) rows = [] @@ -287,8 +255,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): help = cmd.get_short_help_str(limit) rows.append((inflection.dasherize(subcommand), help)) if rows: - with formatter.section(_("Extensions")): - formatter.write_dl(rows) + with formatter.section(_("Extensions")): formatter.write_dl(rows) @click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="openllm") @click.version_option(None, "--version", "-v", message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}") @@ -524,10 +491,11 @@ def build_command( _previously_built = True except bentoml.exceptions.NotFound: bento = bundle.create_bento( - bento_tag, llm_fs, llm, workers_per_resource=workers_per_resource, device=device, adapter_map=adapter_map, quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=enable_features, dockerfile_template=dockerfile_template_path, runtime=runtime, container_registry=container_registry, container_version_strategy=container_version_strategy + bento_tag, llm_fs, llm, workers_per_resource=workers_per_resource, adapter_map=adapter_map, + quantize=quantize, bettertransformer=bettertransformer, extra_dependencies=enable_features, dockerfile_template=dockerfile_template_path, runtime=runtime, + container_registry=container_registry, container_version_strategy=container_version_strategy ) - except Exception as err: - raise err from None + except Exception as err: raise err from None if machine: termui.echo(f"__tag__:{bento.tag}", fg="white") elif output == "pretty": @@ -535,26 +503,17 @@ def build_command( termui.echo("\n" + OPENLLM_FIGLET, fg="white") if not _previously_built: termui.echo(f"Successfully built {bento}.", fg="green") elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. 
To overwrite it pass '--overwrite'.", fg="yellow") - termui.echo( - "📖 Next steps:\n\n" + "* Push to BentoCloud with 'bentoml push':\n" + f" $ bentoml push {bento.tag}\n\n" + "* Containerize your Bento with 'bentoml containerize':\n" + f" $ bentoml containerize {bento.tag} --opt progress=plain" + "\n\n" + " Tip: To enable additional BentoML features for 'containerize', " + "use '--enable-features=FEATURE[,FEATURE]' " + - "[see 'bentoml containerize -h' for more advanced usage]\n", fg="blue", - ) - elif output == "json": - termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) - else: - termui.echo(bento.tag) + termui.echo("📖 Next steps:\n\n" + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" + f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" + "\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n", fg="blue",) + elif output == "json": termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode()) + else: termui.echo(bento.tag) if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push) elif containerize: backend = t.cast("DefaultBuilder", os.environ.get("BENTOML_CONTAINERIZE_BACKEND", "docker")) - try: - bentoml.container.health(backend) - except subprocess.CalledProcessError: - raise OpenLLMException(f"Failed to use backend {backend}") from None - try: - bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io")) - except Exception as err: - raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err + try: bentoml.container.health(backend) + except subprocess.CalledProcessError: raise OpenLLMException(f"Failed to use backend {backend}") from None + try: bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io")) + except Exception as err: raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err return bento @cli.command() @@ -613,7 +572,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo tabulate.PRESERVE_WHITESPACE = True # llm, architecture, url, model_id, installation, cpu, gpu, runtime_impl - data: list[str | tuple[str, str, list[str], str, t.LiteralString, t.LiteralString, tuple[LiteralRuntime, ...]]] = [] + data: list[str | tuple[str, str, list[str], str, LiteralString, LiteralString, tuple[LiteralRuntime, ...]]] = [] for m, v in json_data.items(): data.extend([(m, v["architecture"], v["model_id"], v["installation"], "❌" if not v["cpu"] else "✅", "✅", v["runtime_impl"],)]) column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 4),] @@ -622,7 +581,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo termui.echo("Exception found while parsing models:\n", fg="yellow") for m, err in failed_initialized: termui.echo(f"- {m}: ", fg="yellow", nl=False) - termui.echo(traceback.print_exception(err, limit=3), fg="red") + termui.echo(traceback.print_exception(None, err, None, limit=5), fg="red") # type: ignore[func-returns-value] sys.exit(1) table = tabulate.tabulate(data, tablefmt="fancy_grid", headers=["LLM", "Architecture", "Models Id", 
"pip install", "CPU", "GPU", "Runtime"], maxcolwidths=column_widths) @@ -699,7 +658,7 @@ def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal @click.option("--remote", is_flag=True, default=False, help="Whether or not to use remote tools (inference endpoints) instead of local ones.", show_default=True) @click.option("--opt", help="Define prompt options. " "(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]") -def instruct_command(endpoint: str, timeout: int, agent: t.LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str: +def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str: """Instruct agents interactively for given tasks, from a terminal. \b diff --git a/src/openllm/cli/extension/list_models.py b/src/openllm/cli/extension/list_models.py index 22b024db..e1265f8c 100644 --- a/src/openllm/cli/extension/list_models.py +++ b/src/openllm/cli/extension/list_models.py @@ -1,20 +1,11 @@ - from __future__ import annotations -import typing as t - -import click -import inflection -import orjson - -import bentoml -import openllm +import typing as t, bentoml, openllm, orjson, inflection ,click from bentoml._internal.utils import human_readable_size -from .. import termui -from .._factory import LiteralOutput, model_name_argument, output_option +from openllm.cli import termui +from openllm.cli._factory import LiteralOutput, model_name_argument, output_option -if t.TYPE_CHECKING: - from ..._types import DictStrAny +if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny @click.command("list_models", context_settings=termui.CONTEXT_SETTINGS) @model_name_argument(required=False) diff --git a/src/openllm/cli/extension/playground.py b/src/openllm/cli/extension/playground.py index c494d648..f9ce0e51 100644 --- a/src/openllm/cli/extension/playground.py +++ b/src/openllm/cli/extension/playground.py @@ -1,26 +1,13 @@ - from __future__ import annotations -import importlib.machinery -import logging -import os -import pkgutil -import subprocess -import sys -import tempfile -import typing as t - -import click -import yaml - -from .. import termui -from ... 
import playground -from ...utils import is_jupyter_available, is_jupytext_available, is_notebook_available +import importlib.machinery, logging, os, pkgutil, subprocess, sys, tempfile, typing as t +import click, yaml +from openllm.cli import termui +from openllm import playground +from openllm.utils import is_jupyter_available, is_jupytext_available, is_notebook_available if t.TYPE_CHECKING: - import jupytext - import nbformat - - from openllm._types import DictStrAny + import jupytext, nbformat + from openllm._typing_compat import DictStrAny logger = logging.getLogger(__name__) diff --git a/src/openllm/cli/termui.py b/src/openllm/cli/termui.py index 267d03cc..05de8746 100644 --- a/src/openllm/cli/termui.py +++ b/src/openllm/cli/termui.py @@ -1,22 +1,11 @@ - from __future__ import annotations -import os -import typing as t - -import click -import inflection - -import openllm - -if t.TYPE_CHECKING: - from .._types import DictStrAny +import os, typing as t, click, inflection, openllm +if t.TYPE_CHECKING: from openllm._typing_compat import DictStrAny def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None: attrs["fg"] = fg if not openllm.utils.get_debug_mode() else None if not openllm.utils.get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs) COLUMNS: int = int(os.environ.get("COLUMNS", str(120))) - CONTEXT_SETTINGS: DictStrAny = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore} - __all__ = ["echo", "COLUMNS", "CONTEXT_SETTINGS"] diff --git a/src/openllm/client/__init__.py b/src/openllm/client/__init__.py index aeb4448b..4cb5cfd5 100644 --- a/src/openllm/client/__init__.py +++ b/src/openllm/client/__init__.py @@ -1,12 +1,18 @@ -"""The actual client implementation. +"""OpenLLM Python client. -Use ``openllm.client`` instead. -This holds the implementation of the client, which is used to communicate with the -OpenLLM server. It is used to send requests to the server, and receive responses. 
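+ +To get started, create a client pointing at a running OpenLLM server (adjust the address to match your deployment) and query it: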
+```python +client = openllm.client.HTTPClient("http://localhost:8080") +client.query("What is the difference between gather and scatter?") +``` + +If the server supports embeddings, use `client.embed`: +```python +client.embed("What is the difference between gather and scatter?") +``` """ from __future__ import annotations -from .runtimes import ( +from openllm.client.runtimes import ( AsyncGrpcClient as AsyncGrpcClient, AsyncHTTPClient as AsyncHTTPClient, BaseAsyncClient as BaseAsyncClient, diff --git a/src/openllm/client/runtimes/__init__.py b/src/openllm/client/runtimes/__init__.py index 99a7612d..c0ee21e5 100644 --- a/src/openllm/client/runtimes/__init__.py +++ b/src/openllm/client/runtimes/__init__.py @@ -1,15 +1,15 @@ """Client that supports REST/gRPC protocol to interact with a LLMServer.""" from __future__ import annotations -from .base import ( +from openllm.client.runtimes.base import ( BaseAsyncClient as BaseAsyncClient, BaseClient as BaseClient, ) -from .grpc import ( +from openllm.client.runtimes.grpc import ( AsyncGrpcClient as AsyncGrpcClient, GrpcClient as GrpcClient, ) -from .http import ( +from openllm.client.runtimes.http import ( AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient, ) diff --git a/src/openllm/client/runtimes/base.py b/src/openllm/client/runtimes/base.py index 97f42573..75fa46c4 100644 --- a/src/openllm/client/runtimes/base.py +++ b/src/openllm/client/runtimes/base.py @@ -1,51 +1,35 @@ +# mypy: disable-error-code="name-defined" from __future__ import annotations -import asyncio -import logging -import sys -import typing as t +import asyncio, logging, typing as t +import bentoml, bentoml.client, openllm, httpx from abc import abstractmethod from http import HTTPStatus from urllib.parse import urljoin - -import httpx - -import bentoml -import openllm -from bentoml._internal.client import Client - -# NOTE: We need to do this so that overload can register -# correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): - from typing import overload -else: - from typing_extensions import overload +from openllm._typing_compat import overload, LiteralString T = t.TypeVar("T") +T_co = t.TypeVar("T_co", covariant=True) if t.TYPE_CHECKING: import transformers + from openllm._typing_compat import DictStrAny, LiteralRuntime +transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") # noqa: F811 - from openllm._types import DictStrAny, LiteralRuntime - - class AnnotatedClient(Client, t.Generic[T]): - def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: - ... - - async def async_health(self) -> t.Any: - ... - - def generate_v1(self, qa: openllm.GenerationInput) -> T: - ... - - def metadata_v1(self) -> T: - ... - - def embeddings_v1(self) -> t.Sequence[float]: - ... -else: - AnnotatedClient = Client - transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") - DictStrAny = dict +class AnnotatedClient(t.Protocol[T_co]): + server_url: str + _svc: bentoml.Service + endpoints: list[str] + def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: ... + async def async_health(self) -> t.Any: ... + def generate_v1(self, qa: openllm.GenerationInput) -> T_co: ... + def metadata_v1(self) -> T_co: ... + def embeddings_v1(self) -> t.Sequence[float]: ... + def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ... + async def async_call(self, name: str, *args: t.Any, **attrs: t.Any) -> T_co: ... 
+ @staticmethod + def wait_until_server_ready(host: str, port: int, timeout: float = 30, **kwargs: t.Any) -> None: ... + @staticmethod + def from_url(server_url: str) -> AnnotatedClient[t.Any]: ... logger = logging.getLogger(__name__) @@ -53,12 +37,11 @@ def in_async_context() -> bool: try: _ = asyncio.get_running_loop() return True - except RuntimeError: - return False + except RuntimeError: return False class ClientMeta(t.Generic[T]): _api_version: str - _client_class: type[bentoml.client.Client] + _client_type: t.Literal["GrpcClient", "HTTPClient"] _host: str _port: str @@ -66,15 +49,8 @@ class ClientMeta(t.Generic[T]): __agent__: transformers.HfAgent | None = None __llm__: openllm.LLM[t.Any, t.Any] | None = None - def __init__(self, address: str, timeout: int = 30): - self._address = address - self._timeout = timeout - - def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): - """Initialise subclass for HTTP and gRPC client type.""" - cls._client_class = t.cast(t.Type[bentoml.client.Client], bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient) - cls._api_version = api_version - + def __init__(self, address: str, timeout: int = 30): self._address, self._timeout = address, timeout + def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): cls._client_type, cls._api_version = "HTTPClient" if client_type == "http" else "GrpcClient", api_version @property def _hf_agent(self) -> transformers.HfAgent: if not self.supports_hf_agent: raise openllm.exceptions.OpenLLMException(f"{self.model_name} ({self.framework}) does not support running HF agent.") @@ -82,99 +58,62 @@ class ClientMeta(t.Generic[T]): if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. 
Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'") self.__agent__ = transformers.HfAgent(urljoin(self._address, "/hf/agent")) return self.__agent__ - @property - def _metadata(self) -> T: - if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() - return self.call("metadata") - + @property + def _metadata(self) -> T: return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() if in_async_context() else self.call("metadata") @property @abstractmethod - def model_name(self) -> str: - raise NotImplementedError - + def model_name(self) -> str: raise NotImplementedError @property @abstractmethod - def framework(self) -> LiteralRuntime: - raise NotImplementedError - + def framework(self) -> LiteralRuntime: raise NotImplementedError @property @abstractmethod - def timeout(self) -> int: - raise NotImplementedError - + def timeout(self) -> int: raise NotImplementedError @property @abstractmethod - def model_id(self) -> str: - raise NotImplementedError - + def model_id(self) -> str: raise NotImplementedError @property @abstractmethod - def configuration(self) -> dict[str, t.Any]: - raise NotImplementedError - + def configuration(self) -> dict[str, t.Any]: raise NotImplementedError @property @abstractmethod - def supports_embeddings(self) -> bool: - raise NotImplementedError - + def supports_embeddings(self) -> bool: raise NotImplementedError @property @abstractmethod - def supports_hf_agent(self) -> bool: - raise NotImplementedError + def supports_hf_agent(self) -> bool: raise NotImplementedError + @abstractmethod + def postprocess(self, result: t.Any) -> openllm.GenerationOutput: ... + @abstractmethod + def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: ... + @property + def config(self) -> openllm.LLMConfig: return self.llm.config @property def llm(self) -> openllm.LLM[t.Any, t.Any]: + # XXX: if the server runs vllm or any other framework that is not installed on the client side, this will fail. if self.__llm__ is None: self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name) return self.__llm__ - @property - def config(self) -> openllm.LLMConfig: - return self.llm.config - - def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T: - return self._cached.call(f"{name}_{self._api_version}", *args, **attrs) - - async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> T: - return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs) - + def call(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return self._cached.call(f"{name}_{self._api_version}", *args, **attrs) + async def acall(self, name: str, *args: t.Any, **attrs: t.Any) -> T: return await self._cached.async_call(f"{name}_{self._api_version}", *args, **attrs) @property def _cached(self) -> AnnotatedClient[T]: + client_class = t.cast(AnnotatedClient[T], getattr(bentoml.client, self._client_type)) if self.__client__ is None: - self._client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout) - self.__client__ = t.cast("AnnotatedClient[T]", self._client_class.from_url(self._address)) + client_class.wait_until_server_ready(self._host, int(self._port), timeout=self._timeout) + self.__client__ = client_class.from_url(self._address) return self.__client__ - @abstractmethod - def postprocess(self, result: t.Any) -> openllm.GenerationOutput: - ... - - @abstractmethod - def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: - ... 
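+# NOTE: Illustrative call flow, as a sketch only (HTTPClient is the concrete subclass defined in http.py): +#   client = HTTPClient("localhost:3000")   # __init_subclass__ recorded _client_type = "HTTPClient" +#   client.call("generate", inputs)         # -> self._cached.call("generate_v1", inputs) +# where `_cached` resolves getattr(bentoml.client, self._client_type), waits until the server is ready, +# and memoises the result of from_url(self._address) so subsequent calls reuse the same client.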
- class BaseClient(ClientMeta[T]): - def health(self) -> t.Any: - raise NotImplementedError - - def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: - raise NotImplementedError - - def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: - raise NotImplementedError - + def health(self) -> t.Any: raise NotImplementedError + def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError + def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError @overload - def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: - ... - + def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... @overload - def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: - ... - + def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... @overload - def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: - ... - + def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return_raw_response = attrs.pop("return_raw_response", None) if return_raw_response is not None: @@ -197,21 +136,14 @@ class BaseClient(ClientMeta[T]): # NOTE: Scikit interface @overload - def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: - ... - + def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... @overload - def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: - ... - + def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... @overload - def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: - ... + def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... 
+ def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs)) - def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: - return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], self.query(prompt, **attrs)) - - def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: + def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any: if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") @@ -220,34 +152,21 @@ class BaseClient(ClientMeta[T]): task = kwargs.pop("task", args[0]) return_code = kwargs.pop("return_code", False) remote = kwargs.pop("remote", False) - try: - return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs) + try: return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs) except Exception as err: logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err) - logger.info("Tip: LLMServer at '%s' might not support single generation yet.", self._address) + logger.info("Tip: LLMServer at '%s' might not support 'generate_one'.", self._address) class BaseAsyncClient(ClientMeta[T]): - async def health(self) -> t.Any: - raise NotImplementedError - - async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: - raise NotImplementedError - - async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: - raise NotImplementedError - + async def health(self) -> t.Any: raise NotImplementedError + async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str: raise NotImplementedError + async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: raise NotImplementedError @overload - async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: - ... - + async def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... @overload - async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: - ... - + async def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... @overload - async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: - ... - + async def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return_raw_response = attrs.pop("return_raw_response", None) if return_raw_response is not None: @@ -270,25 +189,16 @@ class BaseAsyncClient(ClientMeta[T]): # NOTE: Scikit interface @overload - async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: - ... - + async def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str: ... @overload - async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: - ... 
- + async def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny: ... @overload - async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: - ... - - async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: - return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs)) - - async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: t.LiteralString = "hf", **attrs: t.Any) -> t.Any: + async def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm.GenerationOutput: ... + async def predict(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | DictStrAny | str: return t.cast(t.Union[openllm.GenerationOutput, DictStrAny, str], await self.query(prompt, **attrs)) + async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any: """Async version of agent.run.""" if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") - async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: if not openllm.utils.is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0") if len(args) > 1: raise ValueError("'args' should only take one positional argument.") @@ -317,9 +227,7 @@ class BaseAsyncClient(ClientMeta[T]): # the below have the same logic as agent.run API explanation, code = clean_code_for_run(result) - _hf_agent.log(f"==Explanation from the agent==\n{explanation}") - _hf_agent.log(f"\n\n==Code generated by the agent==\n{code}") if not return_code: _hf_agent.log("\n\n==Result==") diff --git a/src/openllm/client/runtimes/grpc.py b/src/openllm/client/runtimes/grpc.py index 44069cc4..0115d80d 100644 --- a/src/openllm/client/runtimes/grpc.py +++ b/src/openllm/client/runtimes/grpc.py @@ -1,102 +1,93 @@ from __future__ import annotations -import asyncio -import logging -import typing as t - -import orjson - -import openllm - +import asyncio, logging, typing as t +import orjson, openllm +from openllm._typing_compat import LiteralRuntime from .base import BaseAsyncClient, BaseClient if t.TYPE_CHECKING: from grpc_health.v1 import health_pb2 - from bentoml.grpc.v1.service_pb2 import Response logger = logging.getLogger(__name__) -LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"] - -class GrpcClientMixin: - if t.TYPE_CHECKING: - - @property - def _metadata(self) -> Response: - ... - +class GrpcClient(BaseClient["Response"], client_type="grpc"): + def __init__(self, address: str, timeout: int = 30): + self._host, self._port = address.split(":") + super().__init__(address, timeout) + def health(self) -> health_pb2.HealthCheckResponse: return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService")) @property def model_name(self) -> str: - try: - return self._metadata.json.struct_value.fields["model_name"].string_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata.json.struct_value.fields["model_name"].string_value + except KeyError: raise RuntimeError("Malformed service endpoint. 
(Possible malicious)") from None @property def framework(self) -> LiteralRuntime: try: value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value) if value not in ("pt", "flax", "tf", "vllm"): raise KeyError return value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def timeout(self) -> int: - try: - return int(self._metadata.json.struct_value.fields["timeout"].number_value) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return int(self._metadata.json.struct_value.fields["timeout"].number_value) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def model_id(self) -> str: - try: - return self._metadata.json.struct_value.fields["model_id"].string_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata.json.struct_value.fields["model_id"].string_value + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def configuration(self) -> dict[str, t.Any]: - try: - v = self._metadata.json.struct_value.fields["configuration"].string_value - return orjson.loads(v) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def supports_embeddings(self) -> bool: - try: - return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def supports_hf_agent(self) -> bool: - try: - return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value + except KeyError: raise RuntimeError("Malformed service endpoint. 
(Possible malicious)") from None def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput: - if isinstance(result, dict): - return openllm.GenerationOutput(**result) - from google.protobuf.json_format import MessageToDict - + if isinstance(result, dict): return openllm.GenerationOutput(**result) return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True)) -class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"): +class AsyncGrpcClient(BaseAsyncClient["Response"], client_type="grpc"): def __init__(self, address: str, timeout: int = 30): self._host, self._port = address.split(":") super().__init__(address, timeout) - - def health(self) -> health_pb2.HealthCheckResponse: - return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService")) - -class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"): - def __init__(self, address: str, timeout: int = 30): - self._host, self._port = address.split(":") - super().__init__(address, timeout) - - async def health(self) -> health_pb2.HealthCheckResponse: - return await self._cached.health("bentoml.grpc.v1.BentoService") + async def health(self) -> health_pb2.HealthCheckResponse: return await self._cached.health("bentoml.grpc.v1.BentoService") + @property + def model_name(self) -> str: + try: return self._metadata.json.struct_value.fields["model_name"].string_value + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def framework(self) -> LiteralRuntime: + try: + value = t.cast(LiteralRuntime, self._metadata.json.struct_value.fields["framework"].string_value) + if value not in ("pt", "flax", "tf", "vllm"): raise KeyError + return value + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def timeout(self) -> int: + try: return int(self._metadata.json.struct_value.fields["timeout"].number_value) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def model_id(self) -> str: + try: return self._metadata.json.struct_value.fields["model_id"].string_value + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def configuration(self) -> dict[str, t.Any]: + try: return orjson.loads(self._metadata.json.struct_value.fields["configuration"].string_value) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def supports_embeddings(self) -> bool: + try: return self._metadata.json.struct_value.fields["supports_embeddings"].bool_value + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def supports_hf_agent(self) -> bool: + try: return self._metadata.json.struct_value.fields["supports_hf_agent"].bool_value + except KeyError: raise RuntimeError("Malformed service endpoint. 
(Possible malicious)") from None + def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput: + from google.protobuf.json_format import MessageToDict + if isinstance(result, dict): return openllm.GenerationOutput(**result) + return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True)) diff --git a/src/openllm/client/runtimes/http.py b/src/openllm/client/runtimes/http.py index 6125ccff..f59f6f5a 100644 --- a/src/openllm/client/runtimes/http.py +++ b/src/openllm/client/runtimes/http.py @@ -1,110 +1,98 @@ from __future__ import annotations -import logging -import typing as t +import logging, typing as t from urllib.parse import urljoin, urlparse - -import httpx -import orjson - -import openllm - +import httpx, orjson, openllm from .base import BaseAsyncClient, BaseClient, in_async_context - -if t.TYPE_CHECKING: - from openllm._types import DictStrAny, LiteralRuntime -else: - DictStrAny = dict +from openllm._typing_compat import DictStrAny, LiteralRuntime logger = logging.getLogger(__name__) +def process_address(self: AsyncHTTPClient | HTTPClient, address: str) -> None: + address = address if "://" in address else "http://" + address + parsed = urlparse(address) + self._host, *_port = parsed.netloc.split(":") + if len(_port) == 0: self._port = "80" if parsed.scheme == "http" else "443" + else: self._port = next(iter(_port)) -class HTTPClientMixin: - if t.TYPE_CHECKING: +class HTTPClient(BaseClient[DictStrAny]): + def __init__(self, address: str, timeout: int = 30): + process_address(self, address) + super().__init__(address, timeout) - @property - def _metadata(self) -> DictStrAny: - ... + def health(self) -> t.Any: return self._cached.health() + def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: + if not self.supports_embeddings: raise ValueError("This model does not support embeddings.") + if isinstance(prompt, str): prompt = [prompt] + result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout).json() if in_async_context() else self.call("embeddings", list(prompt)) + return openllm.EmbeddingsOutput(**result) @property def model_name(self) -> str: - try: - return self._metadata["model_name"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["model_name"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def model_id(self) -> str: - try: - return self._metadata["model_name"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["model_name"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def framework(self) -> LiteralRuntime: - try: - return self._metadata["framework"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["framework"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def timeout(self) -> int: - try: - return self._metadata["timeout"] - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata["timeout"] + except KeyError: raise RuntimeError("Malformed service endpoint. 
(Possible malicious)") from None @property def configuration(self) -> dict[str, t.Any]: - try: - return orjson.loads(self._metadata["configuration"]) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return orjson.loads(self._metadata["configuration"]) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def supports_embeddings(self) -> bool: - try: - return self._metadata.get("supports_embeddings", False) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None - + try: return self._metadata.get("supports_embeddings", False) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def supports_hf_agent(self) -> bool: - try: - return self._metadata.get("supports_hf_agent", False) - except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + try: return self._metadata.get("supports_hf_agent", False) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result) - def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: - return openllm.GenerationOutput(**result) - -class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]): +class AsyncHTTPClient(BaseAsyncClient[DictStrAny]): def __init__(self, address: str, timeout: int = 30): - address = address if "://" in address else "http://" + address - self._host, self._port = urlparse(address).netloc.split(":") + process_address(self, address) super().__init__(address, timeout) - def health(self) -> t.Any: - return self._cached.health() - - def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: - if not self.supports_embeddings: - raise ValueError("This model does not support embeddings.") - if isinstance(prompt, str): prompt = [prompt] - if in_async_context(): result = httpx.post(urljoin(self._address, f"/{self._api_version}/embeddings"), json=list(prompt), timeout=self.timeout).json() - else: result = self.call("embeddings", list(prompt)) - return openllm.EmbeddingsOutput(**result) - -class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]): - def __init__(self, address: str, timeout: int = 30): - address = address if "://" in address else "http://" + address - self._host, self._port = urlparse(address).netloc.split(":") - super().__init__(address, timeout) - - async def health(self) -> t.Any: - return await self._cached.async_health() - + async def health(self) -> t.Any: return await self._cached.async_health() async def embed(self, prompt: t.Sequence[str] | str) -> openllm.EmbeddingsOutput: - if not self.supports_embeddings: - raise ValueError("This model does not support embeddings.") + if not self.supports_embeddings: raise ValueError("This model does not support embeddings.") if isinstance(prompt, str): prompt = [prompt] res = await self.acall("embeddings", list(prompt)) return openllm.EmbeddingsOutput(**res) + + @property + def model_name(self) -> str: + try: return self._metadata["model_name"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def model_id(self) -> str: + try: return self._metadata["model_name"] + except KeyError: raise RuntimeError("Malformed service endpoint. 
(Possible malicious)") from None + @property + def framework(self) -> LiteralRuntime: + try: return self._metadata["framework"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def timeout(self) -> int: + try: return self._metadata["timeout"] + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def configuration(self) -> dict[str, t.Any]: + try: return orjson.loads(self._metadata["configuration"]) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def supports_embeddings(self) -> bool: + try: return self._metadata.get("supports_embeddings", False) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + @property + def supports_hf_agent(self) -> bool: + try: return self._metadata.get("supports_hf_agent", False) + except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None + def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result) diff --git a/src/openllm/exceptions.py b/src/openllm/exceptions.py index 45463ec3..86e5d294 100644 --- a/src/openllm/exceptions.py +++ b/src/openllm/exceptions.py @@ -1,28 +1,19 @@ """Base exceptions for OpenLLM. This extends BentoML exceptions.""" from __future__ import annotations - import bentoml - class OpenLLMException(bentoml.exceptions.BentoMLException): """Base class for all OpenLLM exceptions. This extends BentoMLException.""" - class GpuNotAvailableError(OpenLLMException): """Raised when there is no GPU available in given system.""" - class ValidationError(OpenLLMException): """Raised when a validation fails.""" - class ForbiddenAttributeError(OpenLLMException): """Raised when using an _internal field.""" - class MissingAnnotationAttributeError(OpenLLMException): """Raised when a field under openllm.LLMConfig is missing annotations.""" - class MissingDependencyError(BaseException): """Raised when a dependency is missing.""" - class Error(BaseException): """To be used instead of naked raise.""" - class FineTuneStrategyNotSupportedError(OpenLLMException): """Raised when a fine-tune strategy is not supported for given LLM.""" diff --git a/src/openllm/models/__init__.py b/src/openllm/models/__init__.py index 8104c7fd..2fea005e 100644 --- a/src/openllm/models/__init__.py +++ b/src/openllm/models/__init__.py @@ -1,11 +1,11 @@ # This file is generated by tools/update-models-import.py. DO NOT EDIT MANUALLY! # To update this, run ./tools/update-models-import.py from __future__ import annotations -import typing as t +import typing as t, os from openllm.utils import LazyModule _MODELS: set[str] = {"auto", "baichuan", "chatglm", "dolly_v2", "falcon", "flan_t5", "gpt_neox", "llama", "mpt", "opt", "stablelm", "starcoder"} if t.TYPE_CHECKING: from . 
import auto as auto,baichuan as baichuan,chatglm as chatglm,dolly_v2 as dolly_v2,falcon as falcon,flan_t5 as flan_t5,gpt_neox as gpt_neox,llama as llama,mpt as mpt,opt as opt,stablelm as stablelm,starcoder as starcoder -__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS}) +__lazy=LazyModule(__name__, os.path.abspath("__file__"), {k: [] for k in _MODELS}) __all__=__lazy.__all__ __dir__=__lazy.__dir__ __getattr__=__lazy.__getattr__ diff --git a/src/openllm/models/auto/__init__.py b/src/openllm/models/auto/__init__.py index 60c63f4c..016e96f3 100644 --- a/src/openllm/models/auto/__init__.py +++ b/src/openllm/models/auto/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys, typing as t - +import typing as t, os import openllm from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available @@ -40,4 +39,7 @@ else: _import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"]) if t.TYPE_CHECKING: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM -sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure) +__lazy=LazyModule(__name__, os.path.abspath("__file__"), _import_structure) +__all__=__lazy.__all__ +__dir__=__lazy.__dir__ +__getattr__=__lazy.__getattr__ diff --git a/src/openllm/models/auto/configuration_auto.py b/src/openllm/models/auto/configuration_auto.py index 3e759a3b..ff7e7c70 100644 --- a/src/openllm/models/auto/configuration_auto.py +++ b/src/openllm/models/auto/configuration_auto.py @@ -1,16 +1,15 @@ +# mypy: disable-error-code="type-arg" from __future__ import annotations import typing as t from collections import OrderedDict -import inflection - -import openllm +import inflection, openllm from openllm.utils import ReprMixin if t.TYPE_CHECKING: import types + from openllm._typing_compat import LiteralString from collections import _odict_items, _odict_keys, _odict_values - ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]] ConfigKeysView = _odict_keys[str, type[openllm.LLMConfig]] ConfigValuesView = _odict_values[str, type[openllm.LLMConfig]] ConfigItemsView = _odict_items[str, type[openllm.LLMConfig]] @@ -18,8 +17,8 @@ if t.TYPE_CHECKING: # NOTE: This is the entrypoint when adding new model config CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), ("llama", "LlamaConfig"), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig")]) -class _LazyConfigMapping(OrderedDict, ReprMixin): # type: ignore[type-arg] - def __init__(self, mapping: OrderedDict[t.LiteralString, t.LiteralString]): +class _LazyConfigMapping(OrderedDict, ReprMixin): + def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]): self._mapping = mapping self._extra_content: dict[str, t.Any] = {} self._modules: dict[str, types.ModuleType] = {} diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py index 6a2ebaa8..1e82124b 100644 --- a/src/openllm/models/auto/factory.py +++ b/src/openllm/models/auto/factory.py @@ -1,23 +1,16 @@ +# mypy: disable-error-code="type-arg" from __future__ import annotations -import importlib -import inspect -import logging -import typing as t +import importlib, inspect, logging, typing as t from collections import OrderedDict - -import inflection - 
-import openllm +import inflection, openllm from openllm.utils import ReprMixin if t.TYPE_CHECKING: + from openllm._typing_compat import LiteralString, LLMRunner import types from collections import _odict_items, _odict_keys, _odict_values from _typeshed import SupportsIter - - from ..._llm import LLMRunner - ConfigModelOrderedDict = OrderedDict[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] ConfigModelKeysView = _odict_keys[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] ConfigModelValuesView = _odict_values[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] ConfigModelItemsView = _odict_items[type[openllm.LLMConfig], type[openllm.LLM[t.Any, t.Any]]] @@ -25,7 +18,7 @@ if t.TYPE_CHECKING: logger = logging.getLogger(__name__) class BaseAutoLLMClass: - _model_mapping: _LazyAutoMapping + _model_mapping: t.ClassVar[_LazyAutoMapping] def __init__(self, *args: t.Any, **attrs: t.Any): raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.") @classmethod def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False, **attrs: t.Any) -> openllm.LLM[t.Any, t.Any]: @@ -83,13 +76,13 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any: except ValueError: raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None raise ValueError(f"Could not find {attr} in {openllm_module}!") -class _LazyAutoMapping(OrderedDict, ReprMixin): # type: ignore[type-arg] +class _LazyAutoMapping(OrderedDict, ReprMixin): """Based on transformers.models.auto.configuration_auto._LazyAutoMapping. This OrderedDict values() and keys() returns the list instead, so you don't have to do list(mapping.values()) to get the list of values. 
""" - def __init__(self, config_mapping: OrderedDict[t.LiteralString, t.LiteralString], model_mapping: OrderedDict[t.LiteralString, t.LiteralString]): + def __init__(self, config_mapping: OrderedDict[LiteralString, LiteralString], model_mapping: OrderedDict[LiteralString, LiteralString]): self._config_mapping = config_mapping self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} self._model_mapping = model_mapping diff --git a/src/openllm/models/auto/modeling_auto.py b/src/openllm/models/auto/modeling_auto.py index 1a508a07..13d3cd1d 100644 --- a/src/openllm/models/auto/modeling_auto.py +++ b/src/openllm/models/auto/modeling_auto.py @@ -1,11 +1,10 @@ - from __future__ import annotations +import typing as t from collections import OrderedDict - from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass, _LazyAutoMapping MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), ("opt", "OPT"), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan")]) MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) class AutoLLM(BaseAutoLLMClass): - _model_mapping = MODEL_MAPPING + _model_mapping: t.ClassVar = MODEL_MAPPING diff --git a/src/openllm/models/auto/modeling_flax_auto.py b/src/openllm/models/auto/modeling_flax_auto.py index 0e334af7..20d45f6f 100644 --- a/src/openllm/models/auto/modeling_flax_auto.py +++ b/src/openllm/models/auto/modeling_flax_auto.py @@ -1,10 +1,10 @@ from __future__ import annotations +import typing as t from collections import OrderedDict - from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass, _LazyAutoMapping MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT")]) MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES) class AutoFlaxLLM(BaseAutoLLMClass): - _model_mapping = MODEL_FLAX_MAPPING + _model_mapping: t.ClassVar = MODEL_FLAX_MAPPING diff --git a/src/openllm/models/auto/modeling_tf_auto.py b/src/openllm/models/auto/modeling_tf_auto.py index ee164631..9aa6a0a4 100644 --- a/src/openllm/models/auto/modeling_tf_auto.py +++ b/src/openllm/models/auto/modeling_tf_auto.py @@ -1,10 +1,10 @@ from __future__ import annotations +import typing as t from collections import OrderedDict - from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass, _LazyAutoMapping MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT")]) MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES) class AutoTFLLM(BaseAutoLLMClass): - _model_mapping = MODEL_TF_MAPPING + _model_mapping: t.ClassVar = MODEL_TF_MAPPING diff --git a/src/openllm/models/auto/modeling_vllm_auto.py b/src/openllm/models/auto/modeling_vllm_auto.py index 36aa2117..c510441e 100644 --- a/src/openllm/models/auto/modeling_vllm_auto.py +++ b/src/openllm/models/auto/modeling_vllm_auto.py @@ -1,10 +1,10 @@ from __future__ import annotations +import typing as t from collections import OrderedDict - from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import BaseAutoLLMClass, _LazyAutoMapping MODEL_VLLM_MAPPING_NAMES = OrderedDict([("baichuan", "VLLMBaichuan"), ("dolly_v2", "VLLMDollyV2"), ("gpt_neox", "VLLMGPTNeoX"), ("mpt", "VLLMMPT"), ("opt", "VLLMOPT"), ("stablelm", "VLLMStableLM"), ("starcoder", "VLLMStarCoder"), 
("llama", "VLLMLlama")]) MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES) class AutoVLLM(BaseAutoLLMClass): - _model_mapping = MODEL_VLLM_MAPPING + _model_mapping: t.ClassVar = MODEL_VLLM_MAPPING diff --git a/src/openllm/models/baichuan/__init__.py b/src/openllm/models/baichuan/__init__.py index cab3c6c0..f201ef91 100644 --- a/src/openllm/models/baichuan/__init__.py +++ b/src/openllm/models/baichuan/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available, is_vllm_available diff --git a/src/openllm/models/baichuan/configuration_baichuan.py b/src/openllm/models/baichuan/configuration_baichuan.py index 50e132c4..9f7b4122 100644 --- a/src/openllm/models/baichuan/configuration_baichuan.py +++ b/src/openllm/models/baichuan/configuration_baichuan.py @@ -1,5 +1,4 @@ from __future__ import annotations - import openllm class BaichuanConfig(openllm.LLMConfig): diff --git a/src/openllm/models/baichuan/modeling_baichuan.py b/src/openllm/models/baichuan/modeling_baichuan.py index abc2fb83..6b4cd5b9 100644 --- a/src/openllm/models/baichuan/modeling_baichuan.py +++ b/src/openllm/models/baichuan/modeling_baichuan.py @@ -1,9 +1,6 @@ from __future__ import annotations -import sys, typing as t - -import openllm +import typing as t, openllm from openllm._prompt import process_prompt - from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE if t.TYPE_CHECKING: import torch, transformers @@ -15,6 +12,6 @@ class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrai def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] def generate(self, prompt: str, **attrs: t.Any) -> list[str]: inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) - with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined] outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config()) return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) diff --git a/src/openllm/models/baichuan/modeling_vllm_baichuan.py b/src/openllm/models/baichuan/modeling_vllm_baichuan.py index f0daad13..1e9e73d6 100644 --- a/src/openllm/models/baichuan/modeling_vllm_baichuan.py +++ b/src/openllm/models/baichuan/modeling_vllm_baichuan.py @@ -1,15 +1,9 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import typing as t, openllm from openllm._prompt import process_prompt - from .configuration_baichuan import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import vllm, transformers -logger = logging.getLogger(__name__) class VLLMBaichuan(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizerBase"]): __openllm_internal__ = True tokenizer_id = "local" diff --git a/src/openllm/models/chatglm/__init__.py b/src/openllm/models/chatglm/__init__.py index b14952bb..90bacaed 100644 --- a/src/openllm/models/chatglm/__init__.py +++ b/src/openllm/models/chatglm/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_cpm_kernels_available, is_torch_available diff --git 
a/src/openllm/models/chatglm/configuration_chatglm.py b/src/openllm/models/chatglm/configuration_chatglm.py index 230a84e0..a8e7e651 100644 --- a/src/openllm/models/chatglm/configuration_chatglm.py +++ b/src/openllm/models/chatglm/configuration_chatglm.py @@ -1,5 +1,4 @@ from __future__ import annotations - import openllm class ChatGLMConfig(openllm.LLMConfig): diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py index 6c1fcd9e..ebcaa35e 100644 --- a/src/openllm/models/chatglm/modeling_chatglm.py +++ b/src/openllm/models/chatglm/modeling_chatglm.py @@ -1,8 +1,5 @@ from __future__ import annotations -import sys, typing as t - -import openllm - +import typing as t, openllm if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional") diff --git a/src/openllm/models/dolly_v2/__init__.py b/src/openllm/models/dolly_v2/__init__.py index 7820e3d1..fb8ce3d2 100644 --- a/src/openllm/models/dolly_v2/__init__.py +++ b/src/openllm/models/dolly_v2/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_torch_available, is_vllm_available diff --git a/src/openllm/models/dolly_v2/configuration_dolly_v2.py b/src/openllm/models/dolly_v2/configuration_dolly_v2.py index 4d877d32..1b3026a8 100644 --- a/src/openllm/models/dolly_v2/configuration_dolly_v2.py +++ b/src/openllm/models/dolly_v2/configuration_dolly_v2.py @@ -1,8 +1,5 @@ from __future__ import annotations -import sys, typing as t - -import openllm - +import typing as t, openllm if t.TYPE_CHECKING: import transformers class DollyV2Config(openllm.LLMConfig): diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py index 3d08dc3e..e9df5328 100644 --- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py @@ -1,105 +1,107 @@ from __future__ import annotations -import logging -import re -import typing as t - -import openllm +import logging, re, typing as t, openllm from openllm._prompt import process_prompt - +from openllm._typing_compat import overload from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE, END_KEY, RESPONSE_KEY, get_special_token_id if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow") logger = logging.getLogger(__name__) -# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information. 
-class InstructionTextGenerationPipeline(transformers.Pipeline): - def __init__(self, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast, /, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs) - def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]: - if t.TYPE_CHECKING: assert self.tokenizer is not None - preprocess_params: dict[str, t.Any] = {} - # newer versions of the tokenizer configure the response key as a special token. newer versions still may - # append a newline to yield a single token. find whatever token is configured for the response key. - tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None) - response_key_token_id = None - end_key_token_id = None - if tokenizer_response_key: - try: - response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key) - end_key_token_id = get_special_token_id(self.tokenizer, END_KEY) - # Ensure generation stops once it generates "### End" - generate_kwargs["eos_token_id"] = end_key_token_id - except ValueError: pass - forward_params = generate_kwargs - postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id} - if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text - return preprocess_params, forward_params, postprocess_params - def preprocess(self, input_: str, **generate_kwargs: t.Any) -> dict[str, t.Any]: - if t.TYPE_CHECKING: assert self.tokenizer is not None - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_) - inputs = self.tokenizer(prompt_text, return_tensors="pt") - inputs["prompt_text"] = prompt_text - inputs["instruction_text"] = input_ - return inputs - def _forward(self, model_inputs: dict[str, t.Any], **generate_kwargs: t.Any) -> dict[str, t.Any]: - if t.TYPE_CHECKING: assert self.tokenizer is not None - input_ids, attention_mask = model_inputs["input_ids"], model_inputs.get("attention_mask", None) - if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1 - else: in_b = input_ids.shape[0] - generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs) - out_b = generated_sequence.shape[0] - if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) - elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) - instruction_text = model_inputs.pop("instruction_text") - return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text} - def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal["generated_text"], str]]: - if t.TYPE_CHECKING: assert self.tokenizer is not None - _generated_sequence, instruction_text = 
model_outputs["generated_sequence"][0], model_outputs["instruction_text"] - generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist() - records: list[dict[t.Literal["generated_text"], str]] = [] - for sequence in generated_sequence: - # The response will be set to this variable if we can identify it. - decoded = None - # If we have token IDs for the response and end, then we can find the tokens and only decode between them. - if response_key_token_id and end_key_token_id: - # Find where "### Response:" is first found in the generated tokens. Considering this is part of the - # prompt, we should definitely find it. We will return the tokens found after this token. - try: response_pos = sequence.index(response_key_token_id) - except ValueError: response_pos = None - if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence) - if response_pos: - # Next find where "### End" is located. The model has been trained to end its responses with this - # sequence (or actually, the token ID it maps to, since it is a special token). We may not find - # this token, as the response could be truncated. If we don't find it then just return everything - # to the end. Note that even though we set eos_token_id, we still see the this token at the end. - try: end_pos = sequence.index(end_key_token_id) - except ValueError: end_pos = None - decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip() - if not decoded: - # Otherwise we'll decode everything and use a regex to find the response and end. - fully_decoded = self.tokenizer.decode(sequence) - # The response appears after "### Response:". The model has been trained to append "### End" at the - # end. - m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL) - if m: decoded = m.group(1).strip() - else: - # The model might not generate the "### End" sequence before reaching the max tokens. In this case, - # return everything after "### Response:". - m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL) +@overload +def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline: ... +@overload +def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]: ... +def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline: + # Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information. + class InstructionTextGenerationPipeline(transformers.Pipeline): + def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any): super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs) + def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]: + if t.TYPE_CHECKING: assert self.tokenizer is not None + preprocess_params: dict[str, t.Any] = {} + # newer versions of the tokenizer configure the response key as a special token. newer versions still may + # append a newline to yield a single token. 
find whatever token is configured for the response key. + tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None) + response_key_token_id = None + end_key_token_id = None + if tokenizer_response_key: + try: + response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key) + end_key_token_id = get_special_token_id(self.tokenizer, END_KEY) + # Ensure generation stops once it generates "### End" + generate_kwargs["eos_token_id"] = end_key_token_id + except ValueError: pass + forward_params = generate_kwargs + postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id} + if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text + return preprocess_params, forward_params, postprocess_params + def preprocess(self, input_: str, **generate_kwargs: t.Any) -> t.Dict[str, t.Any]: + if t.TYPE_CHECKING: assert self.tokenizer is not None + prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_) + inputs = self.tokenizer(prompt_text, return_tensors="pt") + inputs["prompt_text"] = prompt_text + inputs["instruction_text"] = input_ + return t.cast(t.Dict[str, t.Any], inputs) + def _forward(self, input_tensors: dict[str, t.Any], **generate_kwargs: t.Any) -> transformers.utils.generic.ModelOutput: + if t.TYPE_CHECKING: assert self.tokenizer is not None + input_ids, attention_mask = input_tensors["input_ids"], input_tensors.get("attention_mask", None) + if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1 + else: in_b = input_ids.shape[0] + generated_sequence = self.model.generate(input_ids=input_ids.to(self.model.device) if input_ids is not None else None, attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None, pad_token_id=self.tokenizer.pad_token_id, **generate_kwargs) + out_b = generated_sequence.shape[0] + if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) + elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) + instruction_text = input_tensors.pop("instruction_text") + return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text} + def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal["generated_text"], str]]: + if t.TYPE_CHECKING: assert self.tokenizer is not None + _generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"] + generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist() + records: list[dict[t.Literal["generated_text"], str]] = [] + for sequence in generated_sequence: + # The response will be set to this variable if we can identify it. + decoded = None + # If we have token IDs for the response and end, then we can find the tokens and only decode between them. + if response_key_token_id and end_key_token_id: + # Find where "### Response:" is first found in the generated tokens. Considering this is part of the + # prompt, we should definitely find it. We will return the tokens found after this token. 
+ try: response_pos = sequence.index(response_key_token_id) + except ValueError: response_pos = None + if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence) + if response_pos: + # Next find where "### End" is located. The model has been trained to end its responses with this + # sequence (or actually, the token ID it maps to, since it is a special token). We may not find + # this token, as the response could be truncated. If we don't find it then just return everything + # to the end. Note that even though we set eos_token_id, we still see this token at the end. + try: end_pos = sequence.index(end_key_token_id) + except ValueError: end_pos = None + decoded = self.tokenizer.decode(sequence[response_pos + 1:end_pos]).strip() + if not decoded: + # Otherwise we'll decode everything and use a regex to find the response and end. + fully_decoded = self.tokenizer.decode(sequence) + # The response appears after "### Response:". The model has been trained to append "### End" at the + # end. + m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL) if m: decoded = m.group(1).strip() - else: logger.warning("Failed to find response in:\n%s", fully_decoded) - # If the full text is requested, then append the decoded text to the original instruction. - # This technically isn't the full text, as we format the instruction in the prompt the model has been - # trained on, but to the client it will appear to be the full text. - if return_full_text: decoded = f"{instruction_text}\n{decoded}" - records.append({"generated_text": t.cast(str, decoded)}) - return records + else: + # The model might not generate the "### End" sequence before reaching the max tokens. In this case, + # return everything after "### Response:". + m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL) + if m: decoded = m.group(1).strip() + else: logger.warning("Failed to find response in:\n%s", fully_decoded) + # If the full text is requested, then append the decoded text to the original instruction. + # This technically isn't the full text, as we format the instruction in the prompt the model has been + # trained on, but to the client it will appear to be the full text. 
+ if return_full_text: decoded = f"{instruction_text}\n{decoded}" + records.append({"generated_text": t.cast(str, decoded)}) + return records + return InstructionTextGenerationPipeline() if _init else InstructionTextGenerationPipeline class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedTokenizer"]): __openllm_internal__ = True @property def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16}, {} - def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return InstructionTextGenerationPipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, return_full_text=self.config.return_full_text) + def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline: return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, _init=True, return_full_text=self.config.return_full_text) def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs}, {} def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str: return generation_result[0]["generated_text"] def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]: diff --git a/src/openllm/models/dolly_v2/modeling_vllm_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_vllm_dolly_v2.py index 145686fa..3694ae08 100644 --- a/src/openllm/models/dolly_v2/modeling_vllm_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_vllm_dolly_v2.py @@ -1,12 +1,7 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import vllm, transformers logger = logging.getLogger(__name__) diff --git a/src/openllm/models/falcon/__init__.py b/src/openllm/models/falcon/__init__.py index c2555b89..bfc5341b 100644 --- a/src/openllm/models/falcon/__init__.py +++ b/src/openllm/models/falcon/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_torch_available diff --git a/src/openllm/models/falcon/configuration_falcon.py b/src/openllm/models/falcon/configuration_falcon.py index c0c3a766..6b1c90a2 100644 --- a/src/openllm/models/falcon/configuration_falcon.py +++ b/src/openllm/models/falcon/configuration_falcon.py @@ -1,5 +1,4 @@ from __future__ import annotations - import openllm class FalconConfig(openllm.LLMConfig): diff --git a/src/openllm/models/falcon/modeling_falcon.py b/src/openllm/models/falcon/modeling_falcon.py index 49b0bc83..1351744c 100644 --- a/src/openllm/models/falcon/modeling_falcon.py +++ b/src/openllm/models/falcon/modeling_falcon.py @@ -1,23 +1,19 @@ from __future__ import annotations -import sys, typing as t - 
-import openllm +import typing as t, openllm from openllm._prompt import process_prompt - from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import torch, transformers else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]): __openllm_internal__ = True @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda_device_count() > 1 else None}, {} + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {} def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, top_k: int | None = None, num_return_sequences: int | None = None, eos_token_id: int | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs}, {} def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] def generate(self, prompt: str, **attrs: t.Any) -> list[str]: eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device) - with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): + with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined] return self.tokenizer.batch_decode(self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()), skip_special_tokens=True) def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]: max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device) diff --git a/src/openllm/models/flan_t5/__init__.py b/src/openllm/models/flan_t5/__init__.py index 3b4579cd..189bcc10 100644 --- a/src/openllm/models/flan_t5/__init__.py +++ b/src/openllm/models/flan_t5/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available diff --git a/src/openllm/models/flan_t5/configuration_flan_t5.py b/src/openllm/models/flan_t5/configuration_flan_t5.py index 5b9878cc..e0cd167d 100644 --- a/src/openllm/models/flan_t5/configuration_flan_t5.py +++ b/src/openllm/models/flan_t5/configuration_flan_t5.py @@ -1,5 +1,4 @@ from __future__ import annotations - import openllm class FlanT5Config(openllm.LLMConfig): diff --git a/src/openllm/models/flan_t5/modeling_flan_t5.py b/src/openllm/models/flan_t5/modeling_flan_t5.py index 5d310c38..be61ec1f 100644 --- a/src/openllm/models/flan_t5/modeling_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flan_t5.py @@ -1,11 
+1,7 @@ from __future__ import annotations -import sys, typing as t - -import openllm +import typing as t, openllm from openllm._prompt import process_prompt - from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional") diff --git a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py index 70bd7681..537d6f27 100644 --- a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py @@ -1,11 +1,7 @@ from __future__ import annotations -import sys, typing as t - -import openllm +import typing as t, openllm from openllm._prompt import process_prompt - from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import transformers class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): diff --git a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py index d148b2a8..a3dfaba6 100644 --- a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py @@ -1,11 +1,7 @@ from __future__ import annotations -import sys, typing as t - -import openllm +import typing as t, openllm from openllm._prompt import process_prompt - from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import transformers class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): diff --git a/src/openllm/models/gpt_neox/__init__.py b/src/openllm/models/gpt_neox/__init__.py index be79ebb4..dbf164c7 100644 --- a/src/openllm/models/gpt_neox/__init__.py +++ b/src/openllm/models/gpt_neox/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_torch_available, is_vllm_available diff --git a/src/openllm/models/gpt_neox/configuration_gpt_neox.py b/src/openllm/models/gpt_neox/configuration_gpt_neox.py index d59012f2..8346d05b 100644 --- a/src/openllm/models/gpt_neox/configuration_gpt_neox.py +++ b/src/openllm/models/gpt_neox/configuration_gpt_neox.py @@ -1,5 +1,4 @@ from __future__ import annotations - import openllm class GPTNeoXConfig(openllm.LLMConfig): diff --git a/src/openllm/models/gpt_neox/modeling_gpt_neox.py b/src/openllm/models/gpt_neox/modeling_gpt_neox.py index 65529f3c..e0deff47 100644 --- a/src/openllm/models/gpt_neox/modeling_gpt_neox.py +++ b/src/openllm/models/gpt_neox/modeling_gpt_neox.py @@ -1,16 +1,11 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import torch, transformers else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]): __openllm_internal__ = True def sanitize_parameters(self, prompt: 
str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {} diff --git a/src/openllm/models/gpt_neox/modeling_vllm_gpt_neox.py b/src/openllm/models/gpt_neox/modeling_vllm_gpt_neox.py index 65a617a0..8582b575 100644 --- a/src/openllm/models/gpt_neox/modeling_vllm_gpt_neox.py +++ b/src/openllm/models/gpt_neox/modeling_vllm_gpt_neox.py @@ -1,13 +1,7 @@ - from __future__ import annotations -import logging -import typing as t - -import openllm +import typing as t, openllm, logging from openllm._prompt import process_prompt - from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import vllm, transformers logger = logging.getLogger(__name__) diff --git a/src/openllm/models/llama/__init__.py b/src/openllm/models/llama/__init__.py index c52ebca2..a630485a 100644 --- a/src/openllm/models/llama/__init__.py +++ b/src/openllm/models/llama/__init__.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sys -import typing as t - +import sys, typing as t from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_torch_available, is_vllm_available diff --git a/src/openllm/models/llama/configuration_llama.py b/src/openllm/models/llama/configuration_llama.py index 99748c1b..10dc5b31 100644 --- a/src/openllm/models/llama/configuration_llama.py +++ b/src/openllm/models/llama/configuration_llama.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sys, typing as t - -import openllm +import typing as t, openllm class LlamaConfig(openllm.LLMConfig): """LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. 
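The `sanitize_parameters` overrides in the Falcon, GPTNeoX, and Llama hunks all follow one contract: return the processed prompt plus two dicts, one forwarded to `generate` and one to `postprocess_generate`. A minimal self-contained sketch of that contract follows; the toy `DEFAULT_PROMPT_TEMPLATE` and simplified `process_prompt` below are stand-ins for the real helpers in `openllm._prompt`, not the patch's code:
```
# Sketch only: the three-way sanitize_parameters contract. DEFAULT_PROMPT_TEMPLATE
# and process_prompt are simplified stand-ins assumed for illustration.
from __future__ import annotations
import typing as t

DEFAULT_PROMPT_TEMPLATE = "{instruction}\n### Response:"

def process_prompt(prompt: str, template: str | None, use_template: bool, **_: t.Any) -> str:
  # apply the model's default template only when the caller asks for it
  return template.format(instruction=prompt) if use_template and template is not None else prompt

def sanitize_parameters(prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
  # returns (prompt for the model, kwargs for generate(), kwargs for postprocess_generate())
  return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
```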
diff --git a/src/openllm/models/llama/modeling_llama.py b/src/openllm/models/llama/modeling_llama.py index b5a5e0d8..e7915ee1 100644 --- a/src/openllm/models/llama/modeling_llama.py +++ b/src/openllm/models/llama/modeling_llama.py @@ -1,18 +1,15 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_llama import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import torch, transformers, torch.nn.functional as F else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional") -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]): __openllm_internal__ = True + @property + def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {} def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {} def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str: return generation_result[0] def generate(self, prompt: str, **attrs: t.Any) -> list[str]: @@ -25,4 +22,4 @@ class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaToke mask = attention_mask.unsqueeze(-1).expand(data.size()).float() masked_embeddings = data * mask sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1) - return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item()) + return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item())) diff --git a/src/openllm/models/llama/modeling_vllm_llama.py b/src/openllm/models/llama/modeling_vllm_llama.py index 0f05eab5..11981597 100644 --- a/src/openllm/models/llama/modeling_vllm_llama.py +++ b/src/openllm/models/llama/modeling_vllm_llama.py @@ -1,15 +1,10 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_llama import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import vllm, transformers -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]): __openllm_internal__ = True def sanitize_parameters(self, prompt: str, top_k: int | None = None, top_p: float | None = None, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = False, use_llama2_prompt: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: return 
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k}, {} diff --git a/src/openllm/models/mpt/__init__.py b/src/openllm/models/mpt/__init__.py index d12b25ca..9f7fe1f5 100644 --- a/src/openllm/models/mpt/__init__.py +++ b/src/openllm/models/mpt/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_torch_available, is_vllm_available diff --git a/src/openllm/models/mpt/modeling_mpt.py b/src/openllm/models/mpt/modeling_mpt.py index fb4f2f60..214cfb28 100644 --- a/src/openllm/models/mpt/modeling_mpt.py +++ b/src/openllm/models/mpt/modeling_mpt.py @@ -1,18 +1,13 @@ from __future__ import annotations -import logging -import typing as t - -import bentoml -import openllm +import logging, typing as t, bentoml, openllm from openllm._prompt import process_prompt from openllm.utils import generate_labels, is_triton_available - from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType if t.TYPE_CHECKING: import transformers, torch else: transformers, torch = openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch") -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) def get_mpt_config(model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True) -> transformers.PretrainedConfig: config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device) @@ -62,7 +57,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken attrs = {"do_sample": False if llm_config["temperature"] == 0 else True, "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, "generation_config": llm_config.to_generation_config()} with torch.inference_mode(): if torch.cuda.is_available(): - with torch.autocast("cuda", torch.float16): + with torch.autocast("cuda", torch.float16): # type: ignore[attr-defined] generated_tensors = self.model.generate(**inputs, **attrs) else: generated_tensors = self.model.generate(**inputs, **attrs) return self.tokenizer.batch_decode(generated_tensors, skip_special_tokens=True) diff --git a/src/openllm/models/mpt/modeling_vllm_mpt.py b/src/openllm/models/mpt/modeling_vllm_mpt.py index b9046da6..35ba16ba 100644 --- a/src/openllm/models/mpt/modeling_vllm_mpt.py +++ b/src/openllm/models/mpt/modeling_vllm_mpt.py @@ -1,28 +1,8 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE, MPTPromptType - -if t.TYPE_CHECKING: - import transformers - import vllm +if t.TYPE_CHECKING: import transformers, vllm logger = logging.getLogger(__name__) class VLLMMPT(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]): diff --git a/src/openllm/models/opt/__init__.py b/src/openllm/models/opt/__init__.py index eaebdf4b..ebd225e4 100644 --- a/src/openllm/models/opt/__init__.py +++ b/src/openllm/models/opt/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available diff --git a/src/openllm/models/opt/configuration_opt.py b/src/openllm/models/opt/configuration_opt.py index b0f216c0..0238445f 100644 --- a/src/openllm/models/opt/configuration_opt.py +++ b/src/openllm/models/opt/configuration_opt.py @@ -1,5 +1,4 @@ from __future__ import annotations - import openllm class OPTConfig(openllm.LLMConfig): diff --git a/src/openllm/models/opt/modeling_flax_opt.py b/src/openllm/models/opt/modeling_flax_opt.py index e54821a8..089959fd 100644 --- a/src/openllm/models/opt/modeling_flax_opt.py +++ b/src/openllm/models/opt/modeling_flax_opt.py @@ -1,14 +1,8 @@ from __future__ import annotations -import logging -import typing as t - -import bentoml -import openllm +import logging, typing as t, bentoml, openllm from openllm._prompt import process_prompt from openllm.utils import generate_labels - from .configuration_opt import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import transformers else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") diff --git a/src/openllm/models/opt/modeling_opt.py b/src/openllm/models/opt/modeling_opt.py index 43781609..9c9456e9 100644 --- a/src/openllm/models/opt/modeling_opt.py +++ b/src/openllm/models/opt/modeling_opt.py @@ -1,12 +1,7 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_opt import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import torch, transformers else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") diff --git a/src/openllm/models/opt/modeling_tf_opt.py b/src/openllm/models/opt/modeling_tf_opt.py index 9b057815..a53dd871 100644 --- a/src/openllm/models/opt/modeling_tf_opt.py +++ b/src/openllm/models/opt/modeling_tf_opt.py @@ -1,18 +1,12 @@ from __future__ import annotations -import logging -import typing as t - -import bentoml -import openllm +import logging, typing as t, bentoml, openllm from openllm._prompt import process_prompt from openllm.utils import generate_labels - from .configuration_opt import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import transformers else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]): __openllm_internal__ = True def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: 
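The MPT hunk above repeats a generation pattern used across this patch: turn sampling off when temperature is 0, and run `generate` under `torch.inference_mode()` with CUDA float16 autocast. A sketch of that pattern, assuming `model` and `tokenizer` are an arbitrary transformers causal-LM pair rather than openllm's actual classes:
```
# Sketch only: greedy-vs-sampling switch plus mixed-precision generate, as in the
# MPT hunk. model/tokenizer are assumed to be any transformers causal-LM pair;
# this is not openllm's actual API surface.
from __future__ import annotations
import typing as t
import torch, transformers

def generate(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizerBase, prompt: str, temperature: float = 0.0, **attrs: t.Any) -> list[str]:
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  # temperature == 0 degenerates to greedy decoding, so sampling is disabled
  attrs.update({"do_sample": temperature > 0, "eos_token_id": tokenizer.eos_token_id, "pad_token_id": tokenizer.pad_token_id})
  with torch.inference_mode():
    if torch.cuda.is_available():
      # run the forward passes in float16 autocast on GPU, mirroring the hunk above
      with torch.autocast("cuda", torch.float16):
        out = model.generate(**inputs, **attrs)
    else: out = model.generate(**inputs, **attrs)
  return tokenizer.batch_decode(out, skip_special_tokens=True)
```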
diff --git a/src/openllm/models/opt/modeling_vllm_opt.py b/src/openllm/models/opt/modeling_vllm_opt.py index 5b6bbe51..6591ae5e 100644 --- a/src/openllm/models/opt/modeling_vllm_opt.py +++ b/src/openllm/models/opt/modeling_vllm_opt.py @@ -1,12 +1,7 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_opt import DEFAULT_PROMPT_TEMPLATE - if t.TYPE_CHECKING: import vllm, transformers logger = logging.getLogger(__name__) diff --git a/src/openllm/models/stablelm/__init__.py b/src/openllm/models/stablelm/__init__.py index de41ab31..2927aaec 100644 --- a/src/openllm/models/stablelm/__init__.py +++ b/src/openllm/models/stablelm/__init__.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sys -import typing as t - +import sys, typing as t from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_torch_available, is_vllm_available diff --git a/src/openllm/models/stablelm/configuration_stablelm.py b/src/openllm/models/stablelm/configuration_stablelm.py index cbae8978..553cef61 100644 --- a/src/openllm/models/stablelm/configuration_stablelm.py +++ b/src/openllm/models/stablelm/configuration_stablelm.py @@ -1,5 +1,4 @@ from __future__ import annotations - import openllm class StableLMConfig(openllm.LLMConfig): diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py index fcf3d491..ddeb6cb3 100644 --- a/src/openllm/models/stablelm/modeling_stablelm.py +++ b/src/openllm/models/stablelm/modeling_stablelm.py @@ -1,10 +1,6 @@ from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT if t.TYPE_CHECKING: import transformers, torch diff --git a/src/openllm/models/stablelm/modeling_vllm_stablelm.py b/src/openllm/models/stablelm/modeling_vllm_stablelm.py index 8e8803ff..56eddfe5 100644 --- a/src/openllm/models/stablelm/modeling_vllm_stablelm.py +++ b/src/openllm/models/stablelm/modeling_vllm_stablelm.py @@ -1,25 +1,7 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from __future__ import annotations -import logging -import typing as t - -import openllm +import logging, typing as t, openllm from openllm._prompt import process_prompt - from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE, SYSTEM_PROMPT - if t.TYPE_CHECKING: import vllm, transformers logger = logging.getLogger(__name__) diff --git a/src/openllm/models/starcoder/__init__.py b/src/openllm/models/starcoder/__init__.py index d8c00824..6cc2c524 100644 --- a/src/openllm/models/starcoder/__init__.py +++ b/src/openllm/models/starcoder/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations import sys, typing as t - from openllm.exceptions import MissingDependencyError from openllm.utils import LazyModule, is_torch_available, is_vllm_available diff --git a/src/openllm/models/starcoder/modeling_starcoder.py b/src/openllm/models/starcoder/modeling_starcoder.py index c98c7f10..4abb20ab 100644 --- a/src/openllm/models/starcoder/modeling_starcoder.py +++ b/src/openllm/models/starcoder/modeling_starcoder.py @@ -1,17 +1,11 @@ from __future__ import annotations -import logging -import typing as t - -import bentoml -import openllm +import logging, typing as t, bentoml, openllm from openllm.utils import generate_labels - from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX - if t.TYPE_CHECKING: import torch, transformers else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers") -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]): __openllm_internal__ = True @property @@ -33,7 +27,6 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers. # XXX: This value for pad_token_id is currently a hack, needs more investigation into why the # default starcoder doesn't include the same value as santacoder EOD return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {} - def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0] def generate(self, prompt: str, **attrs: t.Any) -> list[str]: with torch.inference_mode(): diff --git a/src/openllm/models/starcoder/modeling_vllm_starcoder.py b/src/openllm/models/starcoder/modeling_vllm_starcoder.py index c4e19cfa..b54aa63d 100644 --- a/src/openllm/models/starcoder/modeling_vllm_starcoder.py +++ b/src/openllm/models/starcoder/modeling_vllm_starcoder.py @@ -1,27 +1,9 @@ -# Copyright 2023 BentoML Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from __future__ import annotations -import logging -import typing as t - -import openllm - +import logging, typing as t, openllm from .configuration_starcoder import EOD, FIM_INDICATOR, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX - if t.TYPE_CHECKING: import vllm, transformers logger = logging.getLogger(__name__) - class VLLMStarCoder(openllm.LLM["vllm.LLMEngine", "transformers.GPT2TokenizerFast"]): __openllm_internal__ = True tokenizer_id = "local" diff --git a/src/openllm/serialisation/__init__.py b/src/openllm/serialisation/__init__.py index 95e8d78c..495f85c4 100644 --- a/src/openllm/serialisation/__init__.py +++ b/src/openllm/serialisation/__init__.py @@ -23,24 +23,20 @@ llm.save_pretrained("./path/to/local-dolly") ``` """ from __future__ import annotations -import importlib -import typing as t - -import cloudpickle -import fs - -import openllm +import importlib, typing as t +import cloudpickle, fs, openllm from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME +from openllm._typing_compat import M, T, ParamSpec, Concatenate if t.TYPE_CHECKING: - from openllm._llm import T - + import bentoml from . import ( constants as constants, ggml as ggml, transformers as transformers, ) +P = ParamSpec("P") def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T: """Load the tokenizer from BentoML store. @@ -67,9 +63,8 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T: return tokenizer _extras = ["get", "import_model", "save_pretrained", "load_model"] - -def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]: - def caller(llm: openllm.LLM[t.Any, t.Any], *args: t.Any, **kwargs: t.Any) -> t.Any: +def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[openllm.LLM[t.Any, t.Any], P], t.Any]: + def caller(llm: openllm.LLM[t.Any, t.Any], *args: P.args, **kwargs: P.kwargs) -> t.Any: """Generic function dispatch to correct serialisation submodules based on LLM runtime. > [!NOTE] See 'openllm.serialisation.transformers' if 'llm.runtime="transformers"' @@ -77,11 +72,15 @@ def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]: > [!NOTE] See 'openllm.serialisation.ggml' if 'llm.runtime="ggml"' """ return getattr(importlib.import_module(f".{llm.runtime}", __name__), fn)(llm, *args, **kwargs) - return caller -_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []} +if t.TYPE_CHECKING: + def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ... + def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ... + def save_pretrained(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> None: ... + def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ... 
+_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []} __all__ = ["ggml", "transformers", "constants", "load_tokenizer", *_extras] def __dir__() -> list[str]: return sorted(__all__) def __getattr__(name: str) -> t.Any: diff --git a/src/openllm/serialisation/constants.py b/src/openllm/serialisation/constants.py index aa1368a1..4537d9a8 100644 --- a/src/openllm/serialisation/constants.py +++ b/src/openllm/serialisation/constants.py @@ -1,10 +1,3 @@ from __future__ import annotations - -FRAMEWORK_TO_AUTOCLASS_MAPPING = { - "pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), - # NOTE: vllm will use PyTorch implementation of transformers for serialisation - "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM") -} - -# this logic below is synonymous to handling `from_pretrained` attrs. +FRAMEWORK_TO_AUTOCLASS_MAPPING = {"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"), "tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"), "flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"), "vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")} HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token"] diff --git a/src/openllm/serialisation/ggml.py b/src/openllm/serialisation/ggml.py index e18a83a2..c7dc1ffb 100644 --- a/src/openllm/serialisation/ggml.py +++ b/src/openllm/serialisation/ggml.py @@ -4,18 +4,13 @@ This requires ctransformers to be installed. """ from __future__ import annotations import typing as t +import bentoml, openllm -import bentoml -import openllm - -if t.TYPE_CHECKING: - from openllm._llm import M +if t.TYPE_CHECKING: from openllm._typing_compat import M _conversion_strategy = {"pt": "ggml"} -def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model: - raise NotImplementedError("Currently work in progress.") - +def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model: raise NotImplementedError("Currently work in progress.") def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model: """Return an instance of ``bentoml.Model`` from given LLM instance. @@ -35,14 +30,5 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__) raise - -def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: - """Load the model from BentoML store. - - By default, it will try to find check the model in the local store. - If model is not found, it will raises a ``bentoml.exceptions.NotFound``. 
- """ - raise NotImplementedError("Currently work in progress.") - -def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None: - raise NotImplementedError("Currently work in progress.") +def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: raise NotImplementedError("Currently work in progress.") +def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None: raise NotImplementedError("Currently work in progress.") diff --git a/src/openllm/serialisation/transformers/__init__.py b/src/openllm/serialisation/transformers/__init__.py index 88715935..5ab9148d 100644 --- a/src/openllm/serialisation/transformers/__init__.py +++ b/src/openllm/serialisation/transformers/__init__.py @@ -1,19 +1,12 @@ -# mypy: disable-error-code="name-defined,misc" """Serialisation related implementation for Transformers-based implementation.""" from __future__ import annotations -import importlib -import logging -import typing as t - +import importlib, logging, typing as t +import bentoml, openllm from huggingface_hub import snapshot_download from simple_di import Provide, inject - -import bentoml -import openllm from bentoml._internal.configuration.containers import BentoMLContainer from bentoml._internal.models.model import ModelOptions -from openllm.serialisation.transformers.weights import HfIgnore - +from .weights import HfIgnore from ._helpers import ( check_unintialised_params, infer_autoclass_from_llm, @@ -26,16 +19,16 @@ from ._helpers import ( if t.TYPE_CHECKING: import types + import vllm, auto_gptq as autogptq, transformers ,torch import torch.nn from bentoml._internal.models import ModelStore - from openllm._llm import M, T - from openllm._types import DictStrAny - -vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm") -autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq") -transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") -torch = openllm.utils.LazyLoader("torch", globals(), "torch") + from openllm._typing_compat import DictStrAny, M, T +else: + vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm") + autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq") + transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") + torch = openllm.utils.LazyLoader("torch", globals(), "torch") logger = logging.getLogger(__name__) diff --git a/src/openllm/serialisation/transformers/_helpers.py b/src/openllm/serialisation/transformers/_helpers.py index 2b7ca4ff..92f0d4a3 100644 --- a/src/openllm/serialisation/transformers/_helpers.py +++ b/src/openllm/serialisation/transformers/_helpers.py @@ -1,25 +1,14 @@ - from __future__ import annotations -import copy -import typing as t - -import openllm +import copy, typing as t, openllm from bentoml._internal.models.model import ModelInfo, ModelSignature from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, HUB_ATTRS if t.TYPE_CHECKING: - import torch - import transformers + import torch, transformers, bentoml from transformers.models.auto.auto_factory import _BaseAutoModelClass - - import bentoml from bentoml._internal.models.model import ModelSignaturesType - - from ..._llm import M, T - from ..._types import DictStrAny -else: - transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") - torch = openllm.utils.LazyLoader("torch", globals(), "torch") + from openllm._typing_compat import DictStrAny, M, T +else: transformers, torch = 
openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("torch", globals(), "torch") _object_setattr = object.__setattr__ diff --git a/src/openllm/serialisation/transformers/weights.py b/src/openllm/serialisation/transformers/weights.py index e18800d6..ee5372e2 100644 --- a/src/openllm/serialisation/transformers/weights.py +++ b/src/openllm/serialisation/transformers/weights.py @@ -1,22 +1,17 @@ from __future__ import annotations -import typing as t - -import attr +import typing as t, attr from huggingface_hub import HfApi - if t.TYPE_CHECKING: import openllm - from openllm._llm import M, T + from openllm._typing_compat import M, T def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool: return any(s.rfilename.endswith(".safetensors") for s in HfApi().model_info(model_id, revision=revision).siblings) - @attr.define(slots=True) class HfIgnore: safetensors = "*.safetensors" pt = "*.bin" tf = "*.h5" flax = "*.msgpack" - @classmethod def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]: if llm.__llm_implementation__ == "vllm": base = [cls.tf, cls.flax, cls.safetensors] diff --git a/src/openllm/testing.py b/src/openllm/testing.py index b1a1f717..ad2f54cc 100644 --- a/src/openllm/testing.py +++ b/src/openllm/testing.py @@ -1,37 +1,24 @@ """Test utilities for OpenLLM.""" - from __future__ import annotations -import contextlib -import logging -import shutil -import subprocess -import typing as t - -import bentoml -import openllm - -if t.TYPE_CHECKING: - from ._configuration import LiteralRuntime +import contextlib, logging, shutil, subprocess, typing as t, bentoml, openllm +if t.TYPE_CHECKING: from ._typing_compat import LiteralRuntime logger = logging.getLogger(__name__) - @contextlib.contextmanager -def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False,) -> t.Iterator[bentoml.Bento]: +def build_bento(model: str, model_id: str | None = None, quantize: t.Literal["int4", "int8", "gptq"] | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", cleanup: bool = False) -> t.Iterator[bentoml.Bento]: logger.info("Building BentoML for %s", model) bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime) yield bento if cleanup: logger.info("Deleting %s", bento.tag) bentoml.bentos.delete(bento.tag) - @contextlib.contextmanager -def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any,) -> t.Iterator[str]: +def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any) -> t.Iterator[str]: if isinstance(bento, bentoml.Bento): bento_tag = bento.tag else: bento_tag = bentoml.Tag.from_taglike(bento) if image_tag is None: image_tag = str(bento_tag) executable = shutil.which("docker") if not executable: raise RuntimeError("docker executable not found") - try: logger.info("Building container for %s", bento_tag) bentoml.container.build(bento_tag, backend="docker", image_tag=(image_tag,), progress="plain", **attrs,) @@ -40,19 +27,15 @@ def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | N if cleanup: logger.info("Deleting container %s", image_tag) subprocess.check_output([executable, "rmi", "-f", image_tag]) - @contextlib.contextmanager -def prepare(model: str, model_id: str | None = None, 
implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True,) -> t.Iterator[str]: +def prepare(model: str, model_id: str | None = None, implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, cleanup: bool = True) -> t.Iterator[str]: if clean_context is None: clean_context = contextlib.ExitStack() cleanup = True - llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True) bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}") - if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup)) else: bento = bentoml.get(bento_tag) - container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_") if deployment_mode == "container": container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup)) yield container_name diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index 9cf41114..cc39040f 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -4,19 +4,9 @@ Users can import these functions for convenience, but we won't ensure backward compatibility for these functions. So use with caution. """ from __future__ import annotations -import contextlib -import functools -import hashlib -import logging -import logging.config -import os -import sys -import types -import typing as t +import contextlib, functools, hashlib, logging, logging.config, os, sys, types, typing as t, openllm from pathlib import Path - from circus.exc import ConflictError - from bentoml._internal.configuration import ( DEBUG_ENV_VAR as DEBUG_ENV_VAR, GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR, @@ -41,20 +31,14 @@ from openllm.utils.lazy import ( VersionInfo as VersionInfo, ) -logger = logging.getLogger(__name__) +if t.TYPE_CHECKING: + from openllm._typing_compat import AnyCallable, LiteralRuntime +logger = logging.getLogger(__name__) try: from typing import GenericAlias as _TypingGenericAlias # type: ignore except ImportError: _TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] 
and so on) if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,) else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # _GenericAlias is the actual GenericAlias implementation -# NOTE: We need to do this so that overload can register -# correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): from typing import overload as _overload -else: from typing_extensions import overload as _overload -if t.TYPE_CHECKING: - import openllm - - from .._types import AnyCallable, LiteralRuntime DEV_DEBUG_VAR = "OPENLLMDEVDEBUG" @@ -194,9 +178,7 @@ def compose(*funcs: AnyCallable) -> AnyCallable: >>> [f(3*x, x+1) for x in range(1,10)] [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] """ - def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: - return lambda *args, **kwargs: f1(f2(*args, **kwargs)) - + def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable: return lambda *args, **kwargs: f1(f2(*args, **kwargs)) return functools.reduce(compose_two, funcs) def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]: @@ -241,7 +223,6 @@ def resolve_filepath(path: str, ctx: str | None = None) -> str: def validate_is_path(maybe_path: str) -> bool: return os.path.exists(os.path.dirname(resolve_filepath(maybe_path))) def generate_context(framework_name: str) -> _ModelContext: - import openllm framework_versions = {"transformers": pkg.get_pkg_version("transformers")} if openllm.utils.is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch") if openllm.utils.is_tf_available(): @@ -261,14 +242,6 @@ def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t if k.startswith(_TOKENIZER_PREFIX): del attrs[k] return attrs, tokenizer_attrs -@_overload -def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: ... -@_overload -def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: ... -@_overload -def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: ... -@_overload -def infer_auto_class(implementation: t.Literal["vllm"]) -> type[openllm.AutoVLLM]: ... def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM] | type[openllm.AutoTFLLM] | type[openllm.AutoFlaxLLM] | type[openllm.AutoVLLM]: import openllm if implementation == "tf": return openllm.AutoTFLLM diff --git a/src/openllm/utils/analytics.py b/src/openllm/utils/analytics.py index 63246c0e..e19c0872 100644 --- a/src/openllm/utils/analytics.py +++ b/src/openllm/utils/analytics.py @@ -3,33 +3,17 @@ Users can disable this with OPENLLM_DO_NOT_TRACK envvar. 
""" from __future__ import annotations -import contextlib -import functools -import importlib.metadata -import logging -import os -import re -import sys -import typing as t - -import attr - -import openllm +import contextlib, functools, logging, os, re, typing as t, importlib.metadata +import attr, openllm from bentoml._internal.utils import analytics as _internal_analytics - -if sys.version_info[:2] >= (3, 10): - from typing import ParamSpec -else: - from typing_extensions import ParamSpec +from openllm._typing_compat import ParamSpec P = ParamSpec("P") T = t.TypeVar("T") - logger = logging.getLogger(__name__) # This variable is a proxy that will control BENTOML_DO_NOT_TRACK OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK" - DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper() @functools.lru_cache(maxsize=1) @@ -75,18 +59,15 @@ class EventMeta: class ModelSaveEvent(EventMeta): module: str model_size_in_kb: float - @attr.define class OpenllmCliEvent(EventMeta): cmd_group: str cmd_name: str openllm_version: str = importlib.metadata.version("openllm") - # NOTE: reserved for the do_not_track logics duration_in_ms: t.Any = attr.field(default=None) error_type: str = attr.field(default=None) return_code: int = attr.field(default=None) - @attr.define class StartInitEvent(EventMeta): model_name: str diff --git a/src/openllm/utils/codegen.py b/src/openllm/utils/codegen.py index 4704dc60..49431d5c 100644 --- a/src/openllm/utils/codegen.py +++ b/src/openllm/utils/codegen.py @@ -1,22 +1,14 @@ from __future__ import annotations -import functools -import inspect -import linecache -import logging -import string -import types -import typing as t +import functools, inspect, linecache, os, logging, string, types, typing as t from operator import itemgetter from pathlib import Path - import orjson if t.TYPE_CHECKING: from fs.base import FS import openllm - - from .._types import AnyCallable, DictStrAny, ListStr + from openllm._typing_compat import LiteralString, AnyCallable, DictStrAny, ListStr PartialAny = functools.partial[t.Any] _T = t.TypeVar("_T", bound=t.Callable[..., t.Any]) @@ -24,7 +16,7 @@ logger = logging.getLogger(__name__) OPENLLM_MODEL_NAME = "# openllm: model name" OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map" class ModelNameFormatter(string.Formatter): - model_keyword: t.LiteralString = "__model_name__" + model_keyword: LiteralString = "__model_name__" def __init__(self, model_name: str): """The formatter that extends model_name to be formatted the 'service.py'.""" super().__init__() @@ -36,14 +28,13 @@ class ModelNameFormatter(string.Formatter): return True except ValueError: return False class ModelIdFormatter(ModelNameFormatter): - model_keyword: t.LiteralString = "__model_id__" + model_keyword: LiteralString = "__model_id__" class ModelAdapterMapFormatter(ModelNameFormatter): - model_keyword: t.LiteralString = "__model_adapter_map__" + model_keyword: LiteralString = "__model_adapter_map__" -_service_file = Path(__file__).parent.parent / "_service.py" +_service_file = Path(os.path.abspath("__file__")).parent.parent/"_service.py" def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None: - from . 
import DEBUG - + from openllm.utils import DEBUG model_name = llm.config["model_name"] logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/")) with open(_service_file.__fspath__(), "r") as f: src_contents = f.readlines() @@ -119,33 +110,26 @@ def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t. return globs[attr_class_name] def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>" -def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None,) -> AnyCallable: - from . import SHOW_CODEGEN - +def generate_function(typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None) -> AnyCallable: + from openllm.utils import SHOW_CODEGEN script = "def %s(%s):\n %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n ".join(lines) if lines else "pass") meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs) if annotations: meth.__annotations__ = annotations if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script) return meth -def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: t.LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable: - from . import dantic, field_env_key - +def make_env_transformer(cls: type[openllm.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,) -> AnyCallable: + from openllm.utils import dantic, field_env_key def identity(_: str, x_value: t.Any) -> t.Any: return x_value default_callback = identity if default_callback is None else default_callback - globs = {} if globs is None else globs globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,}) - lines: ListStr = ["__env = lambda field_name: __field_env(__model_name, field_name, __suffix)", "return [", " f.evolve(", " default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),", " metadata={", " 'env': f.metadata.get('env', __env(f.name)),", " 'description': f.metadata.get('description', '(not provided)'),", " },", " )", " for f in fields", "]"] fields_ann = "list[attr.Attribute[t.Any]]" - return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann}) - def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T: """Enhance sdk with nice repr that plays well with your brain.""" - from .representation import ReprMixin - + from openllm.utils import ReprMixin if name is None: name = func.__name__.strip("_") _signatures = inspect.signature(func).parameters def _repr(self: ReprMixin) -> str: return f"<generated function {name}>" @@ -153,3 +137,5 @@ def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T: if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}" else: doc = func.__doc__ return t.cast(_T, functools.update_wrapper(types.new_class(name, (t.cast("PartialAny",
functools.partial), ReprMixin), exec_body=lambda ns: ns.update({"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), "__repr_args__": _repr_args, "__repr__": _repr, "__doc__": inspect.cleandoc(doc), "__module__": "openllm",}),)(func, **attrs), func,)) + +__all__ = ["gen_sdk", "make_attr_tuple_class", "make_env_transformer", "generate_unique_filename", "generate_function", "OPENLLM_MODEL_NAME", "OPENLLM_MODEL_ADAPTER_MAP"] diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py index add905f6..960620c3 100644 --- a/src/openllm/utils/dantic.py +++ b/src/openllm/utils/dantic.py @@ -1,30 +1,22 @@ """An interface provides the best of pydantic and attrs.""" from __future__ import annotations -import functools -import importlib -import os -import sys -import typing as t +import functools, importlib, os, sys, typing as t from enum import Enum - -import attr -import click -import click_option_group as cog -import inflection -import orjson +import attr, click, click_option_group as cog, inflection, orjson from click import ( ParamType, shell_completion as sc, types as click_types, ) -if t.TYPE_CHECKING: - from attr import _ValidatorType +if t.TYPE_CHECKING: from attr import _ValidatorType -_T = t.TypeVar("_T") AnyCallable = t.Callable[..., t.Any] FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command]) +__all__ = ["FC", "attrs_to_options", "Field", "parse_type", "is_typing", "is_literal", "ModuleType", "EnumChoice", "LiteralChoice", "allows_multiple", "is_mapping", "is_container", "parse_container_args", "parse_single_arg", "CUDA", "JsonType", "BytesType"] +def __dir__() -> list[str]: return sorted(__all__) + def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any | None = None, suffix_generation: bool = False, suffix_sampling: bool = False,) -> t.Callable[[FC], FC]: # TODO: support parsing nested attrs class and Union envvar = field.metadata["env"] @@ -47,13 +39,11 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any: if env is not None: value = os.environ.get(env, value) if value is not None and isinstance(value, str): - try: - return orjson.loads(value.lower()) - except orjson.JSONDecodeError as err: - raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None + try: return orjson.loads(value.lower()) + except orjson.JSONDecodeError as err: raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None return value -def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[_T] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any,) -> t.Any: +def Field(default: t.Any = None, *, ge: int | float | None = None, le: int | float | None = None, validator: _ValidatorType[t.Any] | None = None, description: str | None = None, env: str | None = None, auto_default: bool = False, use_default_converter: bool = True, **attrs: t.Any) -> t.Any: """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field. 
By default, if both validator and ge are provided, then ge will be diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py index 13855eea..90b148ca 100644 --- a/src/openllm/utils/import_utils.py +++ b/src/openllm/utils/import_utils.py @@ -1,34 +1,16 @@ """Some import utils are vendored from transformers/utils/import_utils.py for performance reasons.""" from __future__ import annotations -import importlib -import importlib.metadata -import importlib.util -import logging -import os -import sys -import typing as t -from abc import ABCMeta +import importlib, importlib.metadata, importlib.util, logging, os, abc, typing as t from collections import OrderedDict - -import inflection -from packaging import version - +import inflection, packaging.version from bentoml._internal.utils import LazyLoader, pkg +from openllm._typing_compat import overload, LiteralString from .representation import ReprMixin -# NOTE: We need to do this so that overload can register -# correct overloads to typing registry -if sys.version_info[:2] >= (3, 11): - from typing import overload -else: - from typing_extensions import overload - if t.TYPE_CHECKING: - BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]] - from .._types import LiteralRuntime -else: - BackendOrderredDict = OrderedDict + BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]] + from openllm._typing_compat import LiteralRuntime logger = logging.getLogger(__name__) OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq",} @@ -104,10 +86,10 @@ def is_tf_available() -> bool: try: _tf_version = importlib.metadata.version(_pkg) break - except importlib.metadata.PackageNotFoundError: pass + except importlib.metadata.PackageNotFoundError: pass # noqa: PERF203 # Ok to ignore here since we actually need to check for all possible tensorflow distributions. _tf_available = _tf_version is not None if _tf_available: - if _tf_version and version.parse(_tf_version) < version.parse("2"): + if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse("2"): logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version) _tf_available = False else: @@ -232,11 +214,13 @@ You can install it with pip: `pip install fairscale`. Please note that you may n your runtime after installation.
""" -BACKENDS_MAPPING = BackendOrderredDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)), - ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), - ("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))]) +BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)), + ("triton", (is_triton_available, TRITON_IMPORT_ERROR)), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), + ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), + ("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), ("fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))]) -class DummyMetaclass(ABCMeta): +class DummyMetaclass(abc.ABCMeta): """Metaclass for dummy object. It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class. @@ -258,19 +242,17 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None: if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name)) if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name)) if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name)) - checks = (BACKENDS_MAPPING[backend] for backend in backends) - failed = [msg.format(name) for available, msg in checks if not available()] + failed = [msg.format(name) for available, msg in (BACKENDS_MAPPING[backend] for backend in backends) if not available()] if failed: raise ImportError("".join(failed)) class EnvVarMixin(ReprMixin): model_name: str - if t.TYPE_CHECKING: - config: str - model_id: str - quantize: str - framework: str - bettertransformer: str - runtime: str + config: str + model_id: str + quantize: str + framework: str + bettertransformer: str + runtime: str @overload def __getitem__(self, item: t.Literal["config"]) -> str: ... 
@overload @@ -297,9 +279,9 @@ class EnvVarMixin(ReprMixin): if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")() elif hasattr(self, item): return getattr(self, item) raise KeyError(f"Key {item} not found in {self}") - def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None: + def __init__(self, model_name: str, implementation: LiteralRuntime = "pt", model_id: str | None = None, bettertransformer: bool | None = None, quantize: LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers") -> None: """EnvVarMixin is a mixin class that returns the value extracted from environment variables.""" - from .._configuration import field_env_key + from openllm._configuration import field_env_key self.model_name = inflection.underscore(model_name) self._implementation = implementation self._model_id = model_id diff --git a/src/openllm/utils/lazy.py b/src/openllm/utils/lazy.py index f8084942..da8c46dd 100644 --- a/src/openllm/utils/lazy.py +++ b/src/openllm/utils/lazy.py @@ -1,19 +1,6 @@ from __future__ import annotations -import functools -import importlib -import importlib.machinery -import importlib.metadata -import importlib.util -import itertools -import os -import time -import types -import typing as t -import warnings - -import attr - -import openllm +import functools, importlib, importlib.machinery, importlib.metadata, importlib.util, itertools, os, time, types, warnings, typing as t +import attr, openllm __all__ = ["VersionInfo", "LazyModule"] # vendored from attrs @@ -68,7 +55,7 @@ class LazyModule(types.ModuleType): for key, values in import_structure.items(): for value in values: self._class_to_module[value] = key # Needed for autocompletion in an IDE - self.__all__ = list(import_structure.keys()) + list(itertools.chain(*import_structure.values())) + self.__all__: list[str] = list(import_structure.keys()) + list(itertools.chain(*import_structure.values())) self.__file__ = module_file self.__spec__ = module_spec or importlib.util.find_spec(name) self.__path__ = [os.path.dirname(module_file)] @@ -93,7 +80,7 @@ class LazyModule(types.ModuleType): if name in dunder_to_metadata: if name not in {"__version_info__", "__copyright__", "__version__"}: warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2) meta = importlib.metadata.metadata("openllm") - project_url = dict(url.split(", ") for url in meta.get_all("Project-URL")) + project_url = dict(url.split(", ") for url in t.cast(t.List[str], meta.get_all("Project-URL"))) if name == "__license__": return "Apache-2.0" elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
elif name in ("__uri__", "__url__"): return project_url["GitHub"] diff --git a/src/openllm/utils/representation.py b/src/openllm/utils/representation.py index 79d1820c..f644b77f 100644 --- a/src/openllm/utils/representation.py +++ b/src/openllm/utils/representation.py @@ -1,57 +1,32 @@ - from __future__ import annotations import typing as t from abc import abstractmethod +import attr, orjson +from openllm import utils +if t.TYPE_CHECKING: from openllm._typing_compat import TypeAlias -import attr -import orjson - -if t.TYPE_CHECKING: - ReprArgs: t.TypeAlias = t.Iterable[tuple[str | None, t.Any]] - +ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None] class ReprMixin: - """This class display possible representation of given class. - - It can be used for implementing __rich_pretty__ and __pretty__ methods in the future. - Most subclass needs to implement a __repr_keys__ property. - - Based on the design from Pydantic. - The __repr__ will display the json representation of the object for easier interaction. - The __str__ will display either __attrs_repr__ or __repr_str__. - """ @property @abstractmethod - def __repr_keys__(self) -> set[str]: - """This can be overriden by base class using this mixin.""" + def __repr_keys__(self) -> set[str]: raise NotImplementedError + """This can be overriden by base class using this mixin.""" + def __repr__(self) -> str: return f"{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}" + """The `__repr__` for any subclass of Mixin. - def __repr__(self) -> str: - """The `__repr__` for any subclass of Mixin. + It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict. + """ + def __str__(self) -> str: return self.__repr_str__(" ") + """The string representation of the given Mixin subclass. - It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict. - """ - from . import bentoml_cattr + It will contains all of the attributes from __repr_keys__ + """ + def __repr_name__(self) -> str: return self.__class__.__name__ + """Name of the instance's class, used in __repr__.""" + def __repr_str__(self, join_str: str) -> str: return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__()) + """To be used with __str__.""" + def __repr_args__(self) -> ReprArgs: return ((k, getattr(self, k)) for k in self.__repr_keys__) + """This can also be overriden by base class using this mixin. - serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()} - return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}" - - def __str__(self) -> str: - """The string representation of the given Mixin subclass. - - It will contains all of the attributes from __repr_keys__ - """ - return self.__repr_str__(" ") - - def __repr_name__(self) -> str: - """Name of the instance's class, used in __repr__.""" - return self.__class__.__name__ - - def __repr_str__(self, join_str: str) -> str: - """To be used with __str__.""" - return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__()) - - def __repr_args__(self) -> ReprArgs: - """This can also be overriden by base class using this mixin. - - By default it does a getattr of the current object from __repr_keys__. 
- """ - return ((k, getattr(self, k)) for k in self.__repr_keys__) + By default it does a getattr of the current object from __repr_keys__. + """ diff --git a/tests/__init__.py b/tests/__init__.py index 4030e737..dd602334 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,16 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from __future__ import annotations import os diff --git a/tests/_strategies/__init__.py b/tests/_strategies/__init__.py index 3a2faba5..e69de29b 100644 --- a/tests/_strategies/__init__.py +++ b/tests/_strategies/__init__.py @@ -1,13 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/_strategies/_configuration.py b/tests/_strategies/_configuration.py index 5cc51585..fecdce38 100644 --- a/tests/_strategies/_configuration.py +++ b/tests/_strategies/_configuration.py @@ -1,17 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations import logging import typing as t diff --git a/tests/client_test.py b/tests/client_test.py deleted file mode 100644 index 65303897..00000000 --- a/tests/client_test.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import openllm - -def test_import_client(): - assert len(openllm.client.__all__) == 6 - assert all(hasattr(openllm.client, attr) for attr in ("AsyncGrpcClient", "GrpcClient", "AsyncHTTPClient", "HTTPClient", "BaseClient", "BaseAsyncClient")) diff --git a/tests/configuration_test.py b/tests/configuration_test.py index a6109985..6d40dcfa 100644 --- a/tests/configuration_test.py +++ b/tests/configuration_test.py @@ -1,19 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""All configuration-related tests for openllm.LLMConfig. This will include testing -for ModelEnv construction and parsing environment variables. -""" from __future__ import annotations import contextlib import logging diff --git a/tests/conftest.py b/tests/conftest.py index d2d96c3a..1c04cdfb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,44 +1,15 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations -import itertools -import os -import typing as t - -import pytest - -import openllm - -if t.TYPE_CHECKING: - from openllm._types import LiteralRuntime +import itertools, os, typing as t, pytest, openllm +if t.TYPE_CHECKING: from openllm._configuration import LiteralRuntime _FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m", "baichuan": "baichuan-inc/Baichuan-7B",} _PROMPT_MAPPING = {"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",} - def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]: - if model not in _FRAMEWORK_MAPPING: - pytest.skip(f"'{model}' is not yet supported in framework testing.") - + if model not in _FRAMEWORK_MAPPING: pytest.skip(f"'{model}' is not yet supported in framework testing.") runtime_impl: tuple[LiteralRuntime, ...] 
= tuple() - if model in openllm.MODEL_MAPPING_NAMES: - runtime_impl += ("pt",) - if model in openllm.MODEL_FLAX_MAPPING_NAMES: - runtime_impl += ("flax",) - if model in openllm.MODEL_TF_MAPPING_NAMES: - runtime_impl += ("tf",) - + if model in openllm.MODEL_MAPPING_NAMES: runtime_impl += ("pt",) + if model in openllm.MODEL_FLAX_MAPPING_NAMES: runtime_impl += ("flax",) + if model in openllm.MODEL_TF_MAPPING_NAMES: runtime_impl += ("tf",) for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()): llm = openllm.Runner(model, model_id=_FRAMEWORK_MAPPING[model], ensure_available=True, implementation=framework, init_local=True,) yield prompt, llm @@ -47,8 +18,6 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: if os.getenv("GITHUB_ACTIONS") is None: if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames: metafunc.parametrize("prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]) - def pytest_sessionfinish(session: pytest.Session, exitstatus: int): # If no tests are collected, pytest exits with code 5, which makes the CI fail. - if exitstatus == 5: - session.exitstatus = 0 + if exitstatus == 5: session.exitstatus = 0 diff --git a/tests/models/__init__.py b/tests/models/__init__.py index 3a2faba5..e69de29b 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -1,13 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 5f9c6dd8..f5404db2 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -1,54 +1,20 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- from __future__ import annotations -import asyncio -import contextlib -import functools -import logging -import sys -import time -import typing as t +import asyncio, contextlib, functools, logging, sys, time, typing as t from abc import ABC, abstractmethod - -import attr -import docker -import docker.errors -import docker.types -import orjson -import pytest +import attr, docker, docker.errors, docker.types, orjson, pytest, openllm from syrupy.extensions.json import JSONSnapshotExtension - -import openllm from openllm._llm import normalise_model_name +from openllm._typing_compat import DictStrAny, ListAny logger = logging.getLogger(__name__) if t.TYPE_CHECKING: import subprocess - from syrupy.assertion import SnapshotAssertion from syrupy.types import PropertyFilter, PropertyMatcher, SerializableData, SerializedData - from openllm._configuration import GenerationConfig - from openllm._types import DictStrAny, ListAny from openllm.client import BaseAsyncClient -else: - DictStrAny = dict - ListAny = list - class ResponseComparator(JSONSnapshotExtension): def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData: if openllm.utils.LazyType(ListAny).isinstance(data): diff --git a/tests/models/flan_t5_test.py b/tests/models/flan_t5_test.py index c0921b43..189fc79f 100644 --- a/tests/models/flan_t5_test.py +++ b/tests/models/flan_t5_test.py @@ -1,17 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations import typing as t diff --git a/tests/models/opt_test.py b/tests/models/opt_test.py index acb75af6..49c4101e 100644 --- a/tests/models/opt_test.py +++ b/tests/models/opt_test.py @@ -1,16 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from __future__ import annotations import typing as t diff --git a/tests/models_test.py b/tests/models_test.py index 1fbf2362..0017da89 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -1,17 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations import os import typing as t diff --git a/tests/package_test.py b/tests/package_test.py index 00c67a48..5fc6bd17 100644 --- a/tests/package_test.py +++ b/tests/package_test.py @@ -1,17 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import annotations import functools import os diff --git a/tests/strategies_test.py b/tests/strategies_test.py index de4293d2..d9520a5a 100644 --- a/tests/strategies_test.py +++ b/tests/strategies_test.py @@ -1,17 +1,3 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- from __future__ import annotations import os import typing as t diff --git a/tools/dependencies.py b/tools/dependencies.py index 167c50b3..f0c8dda4 100755 --- a/tools/dependencies.py +++ b/tools/dependencies.py @@ -120,6 +120,7 @@ _BASE_DEPENDENCIES = [ Dependencies(name="httpx"), Dependencies(name="click", lower_constraint="8.1.3"), Dependencies(name="typing_extensions"), + Dependencies(name="mypy_extensions"), # for mypyc compilation Dependencies(name="ghapi"), Dependencies(name="cuda-python", platform=("Darwin", "ne")), Dependencies(name="bitsandbytes", upper_constraint="0.42"), # 0.41 works with CUDA 11.8 diff --git a/tools/update-brew-tap.py b/tools/update-brew-tap.py index 1a145b5f..ed9091da 100755 --- a/tools/update-brew-tap.py +++ b/tools/update-brew-tap.py @@ -1,19 +1,13 @@ #!/usr/bin/env python3 from __future__ import annotations -import os -import typing as t +import os, typing as t, fs from pathlib import Path - -import fs from ghapi.all import GhApi from jinja2 import Environment from jinja2.loaders import FileSystemLoader from plumbum.cmd import curl, cut, shasum -if t.TYPE_CHECKING: - from plumbum.commands.base import Pipeline - - from openllm._types import DictStrAny +if t.TYPE_CHECKING: from plumbum.commands.base import Pipeline # get git root from this file ROOT = Path(__file__).parent.parent @@ -36,7 +30,7 @@ def main() -> int: _info = api.repos.get() release_tag = api.repos.get_latest_release().name - shadict: DictStrAny = {k: get_release_hash_command(determine_release_url(_info.svn_url, release_tag, k), release_tag)().strip() for k in _gz_strategies} + shadict: dict[str, t.Any] = {k: get_release_hash_command(determine_release_url(_info.svn_url, release_tag, k), release_tag)().strip() for k in _gz_strategies} shadict["archive"] = get_release_hash_command(determine_release_url(_info.svn_url, release_tag, "archive"), release_tag)().strip() ENVIRONMENT = Environment(extensions=["jinja2.ext.do", "jinja2.ext.loopcontrols", "jinja2.ext.debug"], trim_blocks=True, lstrip_blocks=True, loader=FileSystemLoader((ROOT / "Formula").__fspath__(), followlinks=True)) diff --git a/tools/update-config-stubs.py b/tools/update-config-stubs.py index 1e7ab1e5..957ee0bd 100755 --- a/tools/update-config-stubs.py +++ b/tools/update-config-stubs.py @@ -28,14 +28,16 @@ _value_docstring = { This could be one of the keys in 'self.model_ids' or custom users model. This field is required when defining under '__config__'. - """, "model_ids": """A list of supported pretrained models tag for this given runnable. + """, + "model_ids": """A list of supported pretrained model tags for this given runnable. For example: For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl"] This field is required when defining under '__config__'. - """, "architecture": """The model architecture that is supported by this LLM. + """, + "architecture": """The model architecture that is supported by this LLM. Note that any model weights within this architecture generation can always be run and supported by this LLM. @@ -44,16 +46,29 @@ _value_docstring = { ```bash openllm start gpt-neox --model-id stabilityai/stablelm-tuned-alpha-3b - ```""", "default_implementation": """The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. + ```""", + "default_implementation": """The default runtime to run this LLM.
By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. It is a dictionary of key as the accelerator spec in k8s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm') - """, "url": """The resolved url for this LLMConfig.""", "requires_gpu": """Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.""", "trust_remote_code": """Whether to always trust remote code""", "service_name": """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""", + """, + "url": """The resolved url for this LLMConfig.""", + "requires_gpu": """Determines if this model is only available on GPU. By default it supports GPU and falls back to CPU.""", + "trust_remote_code": """Whether to always trust remote code""", + "service_name": """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""", "requirements": """The default PyPI requirements needed to run this given LLM. By default, we will depend on - bentoml, torch, transformers.""", "bettertransformer": """Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.""", "model_type": """The model type for this given LLM. By default, it should be causal language modeling. + bentoml, torch, transformers.""", + "bettertransformer": """Whether to use BetterTransformer for this given LLM. This depends on the model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other model.""", + "model_type": """The model type for this given LLM. By default, it should be causal language modeling. Currently supported 'causal_lm' or 'seq2seq_lm' - """, "runtime": """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""", "name_type": """The default name typed for this model. "dasherize" will convert the name to lowercase and + """, + "runtime": """The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.""", + "name_type": """The default name type for this model. "dasherize" will convert the name to lowercase and replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both - `model_name` and `start_name` must be specified.""", "model_name": """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""", "start_name": """Default name to be used with `openllm start`""", "env": """A EnvVarMixin instance for this LLMConfig.""", "timeout": """The default timeout to be set for this given LLM.""", + `model_name` and `start_name` must be specified.""", + "model_name": """The normalized version of __openllm_start_name__, determined by __openllm_name_type__""", + "start_name": """Default name to be used with `openllm start`""", + "env": """An EnvVarMixin instance for this LLMConfig.""", + "timeout": """The default timeout to be set for this given LLM.""", "workers_per_resource": """The number of workers per resource. This is used to determine the number of workers to use for this model. For example, if this is set to 0.5, then OpenLLM will use 1 worker per 2 resources. If this is set to 1, then OpenLLM will use 1 worker per resource.
If this is set to 2, then OpenLLM will use 2 workers per resource. @@ -62,7 +77,9 @@ _value_docstring = { https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more details. By default, it is set to 1. - """, "fine_tune_strategies": """The fine-tune strategies for this given LLM.""", "tokenizer_class": """Optional tokenizer class for this given LLM. See Llama for example.""", + """, + "fine_tune_strategies": """The fine-tune strategies for this given LLM.""", + "tokenizer_class": """Optional tokenizer class for this given LLM. See Llama for example.""", } _transformed = {"fine_tune_strategies": "t.Dict[AdapterType, FineTuneConfig]"} diff --git a/tools/update-init-import.py b/tools/update-init-import.py deleted file mode 100755 index 2f2b9c3e..00000000 --- a/tools/update-init-import.py +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations -import importlib -import itertools -from pathlib import Path - -_client_all = Path(__file__).parent.parent/"src"/"openllm"/"client.py" - -def main() -> int: - mod = importlib.import_module("openllm.client") - _all = [f'"{i}"' for i in itertools.chain.from_iterable(mod._import_structure.values())] - with _client_all.open("r") as f: processed = f.readlines() - processed = processed[:-1] + [f"__all__=[{','.join(sorted(_all))}]\n"] - with _client_all.open("w") as f: f.writelines(processed) - return 0 - -if __name__ == "__main__": raise SystemExit(main()) diff --git a/tools/update-models-import.py b/tools/update-models-import.py index 0d26b0f7..f1b9ef5a 100755 --- a/tools/update-models-import.py +++ b/tools/update-models-import.py @@ -9,7 +9,7 @@ def create_module_import() -> str: r = [f'"{p.name}"' for p in _TARGET_FILE.parent.glob('*/') if p.name not in ['__pycache__', '__init__.py', '.DS_Store']] return f"_MODELS: set[str] = {{{', '.join(sorted(r))}}}" def create_stubs_import() -> list[str]: return ["if t.TYPE_CHECKING: from . import "+",".join([f"{p.name} as {p.name}" for p in sorted(_TARGET_FILE.parent.glob("*/")) if p.name not in {"__pycache__", "__init__.py", ".DS_Store"}]), - '__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS})', "__all__=__lazy.__all__", "__dir__=__lazy.__dir__", "__getattr__=__lazy.__getattr__\n"] + '__lazy=LazyModule(__name__, os.path.abspath("__file__"), {k: [] for k in _MODELS})', "__all__=__lazy.__all__", "__dir__=__lazy.__dir__", "__getattr__=__lazy.__getattr__\n"] def main() -> int: _path = os.path.join(os.path.basename(os.path.dirname(__file__)), os.path.basename(__file__)) @@ -17,7 +17,7 @@ def main() -> int: f"# This file is generated by {_path}. DO NOT EDIT MANUALLY!", f"# To update this, run ./{_path}", "from __future__ import annotations", - "import typing as t", + "import typing as t, os", "from openllm.utils import LazyModule", create_module_import(), *create_stubs_import(), diff --git a/typings/attr/__init__.pyi b/typings/attr/__init__.pyi index 871a19a3..03d71d5e 100644 --- a/typings/attr/__init__.pyi +++ b/typings/attr/__init__.pyi @@ -1,4 +1,5 @@ import enum +import sys from typing import ( Any, Callable, @@ -8,19 +9,20 @@ from typing import ( Literal, Mapping, Optional, - ParamSpec, Protocol, Sequence, Tuple, Type, - TypeAlias, - TypeGuard, TypeVar, Union, - dataclass_transform, overload, ) +if sys.version_info[:2] >= (3, 11): + from typing import ParamSpec, TypeAlias, TypeGuard, dataclass_transform +else: + from typing_extensions import ParamSpec, TypeAlias, TypeGuard, dataclass_transform + from . 
import ( converters as converters, exceptions as exceptions, @@ -72,9 +74,9 @@ def Factory(factory: Callable[[Any], _T], takes_self: Literal[True]) -> _T: ... @overload def Factory(factory: Callable[[], _T], takes_self: Literal[False]) -> _T: ... -class _CountingAttr(Generic[_T]): +class _CountingAttr: counter: int - _default: _T + _default: Any repr: _ReprArgType cmp: _EqOrderType eq: _EqOrderType @@ -85,8 +87,8 @@ class _CountingAttr(Generic[_T]): init: bool converter: _ConverterType | None metadata: dict[Any, Any] - _validator: _ValidatorType[_T] | None - type: type[_T] | None + _validator: _ValidatorType[Any] | None + type: type[Any] | None kw_only: bool on_setattr: _OnSetAttrType alias: str | None @@ -109,7 +111,7 @@ class Attribute(Generic[_T]): alias: str | None def evolve(self, **changes: Any) -> Attribute[Any]: ... @classmethod - def from_counting_attr(cls, name: str, ca: _CountingAttr[_T], type: Type[Any] | None = None) -> Attribute[_T]: ... + def from_counting_attr(cls, name: str, ca: _CountingAttr, type: Type[Any] | None = None) -> Attribute[_T]: ... # NOTE: We had several choices for the annotation to use for type arg: # 1) Type[_T] @@ -534,12 +536,10 @@ def assoc(inst: _T, **changes: Any) -> _T: ... def evolve(inst: _T, **changes: Any) -> _T: ... # _config -- - def set_run_validators(run: bool) -> None: ... def get_run_validators() -> bool: ... # aliases -- - s = attrs attributes = attrs ib = attrib @@ -547,8 +547,7 @@ attr = attrib dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;) class ReprProtocol(Protocol): - def __call__(__self, self: Any) -> str: ... - + def __call__(__self, self: Any) -> str: ... def _make_init( cls: type[AttrsInstance], attrs: tuple[Attribute[Any], ...], @@ -565,9 +564,9 @@ def _make_init( def _make_repr(attrs: tuple[Attribute[Any]], ns: str | None, cls: AttrsInstance) -> ReprProtocol: ... def _transform_attrs( cls: type[AttrsInstance], - these: dict[str, _CountingAttr[_T]] | None, + these: dict[str, _CountingAttr] | None, auto_attribs: bool, kw_only: bool, collect_by_mro: bool, field_transformer: _FieldTransformer | None, -) -> tuple[tuple[Attribute[_T], ...], tuple[Attribute[_T], ...], dict[Attribute[_T], type[Any]]]: ... +) -> tuple[tuple[Attribute[Any], ...], tuple[Attribute[Any], ...], dict[Attribute[Any], type[Any]]]: ... diff --git a/typings/attr/_cmp.pyi b/typings/attr/_cmp.pyi index 9341dcbc..3ef1ddf2 100644 --- a/typings/attr/_cmp.pyi +++ b/typings/attr/_cmp.pyi @@ -1,13 +1,11 @@ -from typing import Any, Callable, Optional, TypeAlias +import sys +from typing import Any, Callable, Optional + +if sys.version_info[:2] >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias _CompareWithType: TypeAlias = Callable[[Any, Any], bool] -def cmp_using( - eq: Optional[_CompareWithType] = ..., - lt: Optional[_CompareWithType] = ..., - le: Optional[_CompareWithType] = ..., - gt: Optional[_CompareWithType] = ..., - ge: Optional[_CompareWithType] = ..., - require_same_type: bool = ..., - class_name: str = ..., -) -> type[Any]: ... +def cmp_using(eq: Optional[_CompareWithType] = ..., lt: Optional[_CompareWithType] = ..., le: Optional[_CompareWithType] = ..., gt: Optional[_CompareWithType] = ..., ge: Optional[_CompareWithType] = ..., require_same_type: bool = ..., class_name: str = ...,) -> type[Any]: ... 
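The stub changes above and below all standardise on one version-gated import pattern, so the `.pyi` files resolve `typing` symbols from the stdlib on new interpreters and from `typing_extensions` on older ones. A minimal sketch of the pattern as used in these stubs (the `Comparator` alias below is illustrative, not part of the patch):

```python
import sys
from typing import Any, Callable

# TypeAlias landed in `typing` in Python 3.10; older interpreters get the
# backported symbol from `typing_extensions` instead.
if sys.version_info[:2] >= (3, 10):
    from typing import TypeAlias
else:
    from typing_extensions import TypeAlias

# An illustrative alias built with the gated symbol, mirroring
# `_CompareWithType` in typings/attr/_cmp.pyi.
Comparator: TypeAlias = Callable[[Any, Any], bool]
```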
diff --git a/typings/click_option_group/_core.pyi b/typings/click_option_group/_core.pyi index 11277319..8de264a6 100644 --- a/typings/click_option_group/_core.pyi +++ b/typings/click_option_group/_core.pyi @@ -1,6 +1,9 @@ -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Tuple, TypeAlias, TypeVar, Union - -import click +import sys, click +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Set, Tuple, TypeVar, Union +if sys.version_info[:2] >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias AnyCallable: TypeAlias = Callable[..., Any] _FC = TypeVar("_FC", bound=Union[AnyCallable, click.Command]) diff --git a/typings/deepmerge/merger.pyi b/typings/deepmerge/merger.pyi index 1eca04c8..cdae8ab2 100644 --- a/typings/deepmerge/merger.pyi +++ b/typings/deepmerge/merger.pyi @@ -1,5 +1,9 @@ -from typing import Any, Dict, List, Tuple, TypeAlias, Union - +import sys +from typing import Any, Dict, List, Tuple, Union +if sys.version_info[:2] >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias from .strategy.core import StrategyList from .strategy.dict import DictStrategies from .strategy.list import ListStrategies diff --git a/typings/deepmerge/strategy/core.pyi b/typings/deepmerge/strategy/core.pyi index 7915d442..2117dd88 100644 --- a/typings/deepmerge/strategy/core.pyi +++ b/typings/deepmerge/strategy/core.pyi @@ -1,4 +1,10 @@ -from typing import Any, Callable, List, Optional, TypeAlias, Union +import sys +from typing import Any, Callable, List, Optional, Union +if sys.version_info[:2] >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + _StringOrFunction: TypeAlias = Union[str, Callable[..., Any]] STRATEGY_END: object = ... diff --git a/typings/nbformat/_struct.pyi b/typings/nbformat/_struct.pyi index ac50e0af..c39bfe08 100644 --- a/typings/nbformat/_struct.pyi +++ b/typings/nbformat/_struct.pyi @@ -1,8 +1,7 @@ -"""A dict subclass that supports attribute style access. - -Can probably be replaced by types.SimpleNamespace from Python 3.3 -""" -from typing import Any, Dict, Self +import sys +from typing import Any, Dict +if sys.version_info[:2] >= (3,11): from typing import Self +else: from typing_extensions import Self class Struct(Dict[str, Any]): _allownew: bool = True diff --git a/typings/rsmiBindings.pyi b/typings/rsmiBindings.pyi index 355ff88b..f30195f6 100644 --- a/typings/rsmiBindings.pyi +++ b/typings/rsmiBindings.pyi @@ -1,11 +1,16 @@ # See https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py import ctypes -from typing import Any, Literal, LiteralString +import sys +from typing import Any, Literal + +if sys.version_info[:2] >= (3, 11): + from typing import LiteralString +else: + from typing_extensions import LiteralString class rocmsmi(ctypes.CDLL): - @staticmethod - def rsmi_num_monitor_devices(num_devices: ctypes._CArgObject) -> Any: ... - + @staticmethod + def rsmi_num_monitor_devices(num_devices: ctypes._CArgObject) -> Any: ... # Device ID dv_id: ctypes.c_uint64 = ... # GPU ID @@ -17,23 +22,23 @@ RSMI_MAX_FAN_SPEED: Literal[255] = ... RSMI_NUM_VOLTAGE_CURVE_POINTS: Literal[3] = ... class rsmi_status_t(ctypes.c_int): - RSMI_STATUS_SUCCESS: Literal[0x0] = ... - RSMI_STATUS_INVALID_ARGS: Literal[0x1] = ... - RSMI_STATUS_NOT_SUPPORTED: Literal[0x2] = ... - RSMI_STATUS_FILE_ERROR: Literal[0x3] = ... - RSMI_STATUS_PERMISSION: Literal[0x4] = ... 
-    RSMI_STATUS_OUT_OF_RESOURCES: Literal[0x5] = ...
-    RSMI_STATUS_INTERNAL_EXCEPTION: Literal[0x6] = ...
-    RSMI_STATUS_INPUT_OUT_OF_BOUNDS: Literal[0x7] = ...
-    RSMI_STATUS_INIT_ERROR: Literal[0x8] = ...
-    RSMI_INITIALIZATION_ERROR = RSMI_STATUS_INIT_ERROR
-    RSMI_STATUS_NOT_YET_IMPLEMENTED: Literal[0x9] = ...
-    RSMI_STATUS_NOT_FOUND: Literal[0xA] = ...
-    RSMI_STATUS_INSUFFICIENT_SIZE: Literal[0xB] = ...
-    RSMI_STATUS_INTERRUPT: Literal[0xC] = ...
-    RSMI_STATUS_UNEXPECTED_SIZE: Literal[0xD] = ...
-    RSMI_STATUS_NO_DATA: Literal[0xE] = ...
-    RSMI_STATUS_UNKNOWN_ERROR: Literal[0xFFFFFFFF] = ...
+  RSMI_STATUS_SUCCESS: Literal[0x0] = ...
+  RSMI_STATUS_INVALID_ARGS: Literal[0x1] = ...
+  RSMI_STATUS_NOT_SUPPORTED: Literal[0x2] = ...
+  RSMI_STATUS_FILE_ERROR: Literal[0x3] = ...
+  RSMI_STATUS_PERMISSION: Literal[0x4] = ...
+  RSMI_STATUS_OUT_OF_RESOURCES: Literal[0x5] = ...
+  RSMI_STATUS_INTERNAL_EXCEPTION: Literal[0x6] = ...
+  RSMI_STATUS_INPUT_OUT_OF_BOUNDS: Literal[0x7] = ...
+  RSMI_STATUS_INIT_ERROR: Literal[0x8] = ...
+  RSMI_INITIALIZATION_ERROR = RSMI_STATUS_INIT_ERROR
+  RSMI_STATUS_NOT_YET_IMPLEMENTED: Literal[0x9] = ...
+  RSMI_STATUS_NOT_FOUND: Literal[0xA] = ...
+  RSMI_STATUS_INSUFFICIENT_SIZE: Literal[0xB] = ...
+  RSMI_STATUS_INTERRUPT: Literal[0xC] = ...
+  RSMI_STATUS_UNEXPECTED_SIZE: Literal[0xD] = ...
+  RSMI_STATUS_NO_DATA: Literal[0xE] = ...
+  RSMI_STATUS_UNKNOWN_ERROR: Literal[0xFFFFFFFF] = ...

 # Dictionary of rsmi ret codes and its verbose output
 rsmi_status_verbose_err_out: dict[LiteralString, LiteralString] = ...
diff --git a/typings/simple_di/__init__.pyi b/typings/simple_di/__init__.pyi
new file mode 100644
index 00000000..c0142362
--- /dev/null
+++ b/typings/simple_di/__init__.pyi
@@ -0,0 +1,29 @@
+from typing import Any, Callable, Generator, Generic, Tuple, TypeVar, Union, overload
+
+from _typeshed import Incomplete
+
+class _SentinelClass: ...
+_VT = TypeVar("_VT")
+
+class Provider(Generic[_VT]):
+  STATE_FIELDS: Tuple[str, ...]
+  def __init__(self) -> None: ...
+  def set(self, value: Union[_SentinelClass, _VT]) -> None: ...
+  def patch(self, value: Union[_SentinelClass, _VT]) -> Generator[None, None, None]: ...
+  def get(self) -> _VT: ...
+  def reset(self) -> None: ...
+
+class _ProvideClass:
+  def __getitem__(self, provider: Provider[_VT]) -> _VT: ...
+
+Provide: Incomplete
+_AnyCallable = TypeVar("_AnyCallable", bound=Callable[..., Any])
+
+
+@overload
+def inject(func: _AnyCallable) -> _AnyCallable: ...
+@overload
+def inject(func: None = ..., squeeze_none: bool = ...) -> Callable[[_AnyCallable], _AnyCallable]: ...
+def sync_container(from_: Any, to_: Any) -> None: ...
+
+container: Incomplete
diff --git a/typings/simple_di/providers.pyi b/typings/simple_di/providers.pyi
new file mode 100644
index 00000000..fe3b8228
--- /dev/null
+++ b/typings/simple_di/providers.pyi
@@ -0,0 +1,54 @@
+import sys
+from typing import (
+  Any,
+  Callable as CallableType,
+  Dict,
+  Tuple,
+  Union,
+)
+if sys.version_info[:2] >= (3, 10):
+  from typing import TypeAlias
+else:
+  from typing_extensions import TypeAlias
+
+
+from _typeshed import Incomplete
+
+from . import _VT, Provider, _SentinelClass
+
+class Placeholder(Provider[_VT]): ...
+
+class Static(Provider[_VT]):
+  STATE_FIELDS: Tuple[str, ...]
+  def __init__(self, value: _VT) -> None: ...
+
+class Factory(Provider[_VT]):
+  STATE_FIELDS: Tuple[str, ...]
+  def __init__(self, func: CallableType[..., _VT], *args: Any, **kwargs: Any) -> None: ...
+
+class SingletonFactory(Factory[_VT]):
+  STATE_FIELDS: Tuple[str, ...]
+  def __init__(self, func: CallableType[..., _VT], *args: Any, **kwargs: Any) -> None: ...
+Callable = Factory
+MemoizedCallable = SingletonFactory
+ConfigDictType: TypeAlias = Dict[Union[str, int], Any]
+PathItemType: TypeAlias = Union[int, str, Provider[int], Provider[str]]
+
+class Configuration(Provider[ConfigDictType]):
+  STATE_FIELDS: Tuple[str, ...]
+  fallback: Incomplete
+  def __init__(self, data: Union[_SentinelClass, ConfigDictType] = ..., fallback: Any = ...) -> None: ...
+  def set(self, value: Union[_SentinelClass, ConfigDictType]) -> None: ...
+  def get(self) -> Union[ConfigDictType, Any]: ...
+  def reset(self) -> None: ...
+  def __getattr__(self, name: str) -> _ConfigurationItem: ...
+  def __getitem__(self, key: PathItemType) -> _ConfigurationItem: ...
+
+class _ConfigurationItem(Provider[Any]):
+  STATE_FIELDS: Tuple[str, ...]
+  def __init__(self, config: Configuration, path: Tuple[PathItemType, ...]) -> None: ...
+  def set(self, value: Any) -> None: ...
+  def get(self) -> Any: ...
+  def reset(self) -> None: ...
+  def __getattr__(self, name: str) -> _ConfigurationItem: ...
+  def __getitem__(self, key: PathItemType) -> _ConfigurationItem: ...
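For readers unfamiliar with simple_di, a sketch of how the API stubbed in `typings/simple_di/__init__.pyi` is typically wired together. It uses only names declared in the stubs (`Static`, `Provide`, `inject`); the `serve` function and `port` provider are invented for illustration, and it assumes a bare provider may be injected outside a `container`, which is how simple_di resolves `Provide[...]` defaults:

from simple_di import Provide, inject
from simple_di.providers import Static

# A Provider[int] holding a constant; set/get/reset come from Provider[_VT].
port = Static(3000)

@inject
def serve(port: int = Provide[port]) -> str:
  # `inject` replaces the `Provide[...]` default with `port.get()` at call time.
  return f"listening on :{port}"

print(serve())          # -> listening on :3000
port.set(8080)          # swap the provided value at runtime
print(serve())          # -> listening on :8080
print(serve(port=1234)) # an explicit argument bypasses injection
port.reset()            # restore the original value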
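Likewise, `Configuration` in `providers.pyi` is a provider over a nested dict: the `__getattr__`/`__getitem__` stubs returning `_ConfigurationItem` are what let attribute- and key-style paths into that dict behave as providers themselves. A hypothetical sketch based only on the stubbed signatures; the `api_server`/`port` keys are invented, and the leaf-override behavior of `_ConfigurationItem.set` is an assumption about the runtime library:

from simple_di.providers import Configuration

config = Configuration(data={"api_server": {"port": 3000}})

# Both access styles return _ConfigurationItem, itself a Provider[Any],
# so a leaf can be passed around and resolved (or overridden) lazily.
port = config.api_server.port
same = config["api_server"]["port"]

print(port.get())    # -> 3000
port.set(8080)       # assumed to override just this leaf in the dict
print(config.get())  # -> {'api_server': {'port': 8080}}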