mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 02:20:32 -04:00
feat(test): snapshot testing (#107)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
2
.gitattributes
vendored
2
.gitattributes
vendored
@@ -1,3 +1,5 @@
|
||||
nightly-requirements.txt linguist-generated=true
|
||||
nightly-requirements-gpu.txt linguist-generated=true
|
||||
tests/models/__snapshots__/* linguist-generated=true
|
||||
typings/**/*.pyi linguist-generated=true
|
||||
* text=auto eol=lf
|
||||
|
||||
11
.github/actions/setup-repo/action.yml
vendored
11
.github/actions/setup-repo/action.yml
vendored
@@ -31,10 +31,6 @@ runs:
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
architecture: ${{ inputs.architecture }}
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: '17'
|
||||
- name: Get cache key prefix
|
||||
id: get-cache-key-prefix
|
||||
shell: bash
|
||||
@@ -54,10 +50,3 @@ runs:
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: pip install hatch towncrier
|
||||
- name: Install pyright
|
||||
shell: bash
|
||||
run: npm install -g npm@^7 pyright
|
||||
- name: Setup bufbuild/buf
|
||||
uses: bufbuild/buf-setup-action@v1.20.0
|
||||
with:
|
||||
github_token: ${{ github.token }}
|
||||
|
||||
17
.github/workflows/ci.yml
vendored
17
.github/workflows/ci.yml
vendored
@@ -29,6 +29,20 @@ defaults:
|
||||
run:
|
||||
shell: bash --noprofile --norc -exo pipefail {0}
|
||||
jobs:
|
||||
quality:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event_name == 'pull_request'
|
||||
name: quality-check
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup CI
|
||||
uses: ./.github/actions/setup-repo
|
||||
with:
|
||||
python-version: ${{ env.STABLE_PYTHON_VERSION }}
|
||||
- name: Run type check
|
||||
run: hatch run typing
|
||||
tests:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }}
|
||||
@@ -47,7 +61,7 @@ jobs:
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Run tests
|
||||
run: hatch run full
|
||||
run: hatch run tests:python
|
||||
- name: Disambiguate coverage filename
|
||||
run: mv .coverage ".coverage.${{ matrix.os }}.${{ matrix.python-version }}"
|
||||
- name: Upload coverage data
|
||||
@@ -99,6 +113,7 @@ jobs:
|
||||
needs:
|
||||
- coverage
|
||||
- tests
|
||||
- quality
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Decide whether the needed jobs succeeded or failed
|
||||
|
||||
46
.github/workflows/cleanup-cache.yml
vendored
46
.github/workflows/cleanup-cache.yml
vendored
@@ -1,46 +0,0 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: cache-cleanup
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- closed
|
||||
jobs:
|
||||
cleanup:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository_owner == 'bentoml'
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
- name: Cleanup
|
||||
run: |
|
||||
gh extension install actions/gh-actions-cache
|
||||
|
||||
REPO=${{ github.repository }}
|
||||
BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge"
|
||||
|
||||
echo "Fetching list of cache key"
|
||||
cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 )
|
||||
|
||||
## Setting this to not fail the workflow while deleting cache keys.
|
||||
set +e
|
||||
echo "Deleting caches..."
|
||||
for cacheKey in $cacheKeysForPR
|
||||
do
|
||||
gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm
|
||||
done
|
||||
echo "Done"
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -14,16 +14,16 @@
|
||||
|
||||
ci:
|
||||
autoupdate_schedule: weekly
|
||||
skip: [check-models-table-update, check-models-table-update, changelog-dry-run]
|
||||
autofix_commit_msg: "ci: auto fixes from pre-commit.ci\nFor more information, see https://pre-commit.ci"
|
||||
skip: [check-models-table-update, check-models-table-update, changelog-dry-run, typecheck]
|
||||
autofix_commit_msg: "ci: auto fixes from pre-commit.ci\n\nFor more information, see https://pre-commit.ci"
|
||||
autoupdate_commit_msg: 'ci: pre-commit autoupdate [pre-commit.ci]'
|
||||
exclude: '.*\.(css|js|svg)$'
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: 'v0.0.275'
|
||||
rev: 'v0.0.277'
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix, --show-fixes]
|
||||
args: [--exit-non-zero-on-fix, --show-fixes]
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
@@ -37,6 +37,13 @@ repos:
|
||||
args: [--config=pyproject.toml]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: typecheck
|
||||
name: type-check
|
||||
entry: pyright src/openllm --level error
|
||||
types: [python]
|
||||
language: node
|
||||
pass_filenames: false
|
||||
additional_dependencies: ['pyright@1.1.316']
|
||||
- id: check-license-header
|
||||
name: check for license headers
|
||||
entry: ./tools/assert-license-headers
|
||||
@@ -69,3 +76,10 @@ repos:
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
exclude: |
|
||||
(?x)^(
|
||||
tests/models/.*
|
||||
)$
|
||||
- id: check-yaml
|
||||
args: ['--unsafe']
|
||||
- id: check-toml
|
||||
|
||||
@@ -100,28 +100,25 @@ After setting up your environment, here's how you can start contributing:
|
||||
3. Run all formatter and linter with `hatch`:
|
||||
|
||||
```bash
|
||||
hatch run fmt
|
||||
hatch run quality
|
||||
```
|
||||
4. Write tests that verify your feature or fix (see
|
||||
[Writing Tests](#writing-tests) below).
|
||||
5. Run all tests to ensure your changes haven't broken anything:
|
||||
|
||||
```bash
|
||||
hatch run full
|
||||
hatch run tests:python
|
||||
```
|
||||
|
||||
6. Commit your changes:
|
||||
|
||||
```bash
|
||||
git commit -m "Add my feature"
|
||||
```
|
||||
|
||||
7. Push your changes to your fork:
|
||||
|
||||
```bash
|
||||
git push origin feature/my-feature
|
||||
```
|
||||
|
||||
8. Submit a Pull Request on GitHub.
|
||||
|
||||
## Using a custom fork
|
||||
@@ -141,7 +138,13 @@ directory and their filenames start with `test_`.
|
||||
Run all tests with:
|
||||
|
||||
```bash
|
||||
hatch run full
|
||||
hatch run tests:python
|
||||
```
|
||||
|
||||
Run snapshot testing for model outputs:
|
||||
|
||||
```bash
|
||||
hatch run tests:models
|
||||
```
|
||||
|
||||
## Releasing a New Version
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
</a><a href="https://l.bentoml.com/join-openllm-discord">
|
||||
<img src="https://badgen.net/badge/icon/OpenLLM/7289da?icon=discord&label=Join%20Us" alt="Discord" />
|
||||
</a><br>
|
||||
</a><a href="https://pdm.fming.dev">
|
||||
<img src="https://img.shields.io/badge/pdm-managed-blueviolet" alt="PDM" />
|
||||
</a><br>
|
||||
<p>An open platform for operating large language models (LLMs) in production.</br>
|
||||
Fine-tune, serve, deploy, and monitor any LLMs with ease.</p>
|
||||
<i></i>
|
||||
|
||||
26
changelog.d/107.fix.md
Normal file
26
changelog.d/107.fix.md
Normal file
@@ -0,0 +1,26 @@
|
||||
Fixes relative model_id handling for running LLM within the container.
|
||||
|
||||
Added support for building container directly with `openllm build`. Users now
|
||||
can do `openllm build --format=container`:
|
||||
|
||||
```bash
|
||||
openllm build flan-t5 --format=container
|
||||
```
|
||||
|
||||
This is equivalent to:
|
||||
|
||||
```bash
|
||||
openllm build flan-t5 && bentoml containerize google-flan-t5-large-service
|
||||
```
|
||||
|
||||
Added Snapshot testing and more robust edge cases for model testing
|
||||
|
||||
General improvement in `openllm.LLM.import_model` where it will parse santised
|
||||
parameters automatically.
|
||||
|
||||
Fixes `openllm start <bento>` to use correct `model_id`, ignoring `--model-id`
|
||||
(The correct behaviour)
|
||||
|
||||
Fixes `--workers-per-resource conserved` to respect `--device`
|
||||
|
||||
Added initial interface for `LLM.embeddings`
|
||||
@@ -13,9 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
@@ -36,11 +33,9 @@ class Query(BaseModel):
|
||||
|
||||
|
||||
def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM:
|
||||
args = [sys.executable, "-m", "openllm", "download", model_name]
|
||||
if model_id:
|
||||
args += ["--model-id", model_id]
|
||||
subprocess.check_output(args)
|
||||
return OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
|
||||
lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False)
|
||||
lc_llm.runner.download_model()
|
||||
return lc_llm
|
||||
|
||||
|
||||
llm = gen_llm("dolly-v2", model_id="databricks/dolly-v2-7b")
|
||||
|
||||
27
hatch.toml
27
hatch.toml
@@ -8,8 +8,6 @@ dependencies = [
|
||||
"tomlkit",
|
||||
# NOTE: Using under ./tools/update-readme.py
|
||||
"markdown-it-py",
|
||||
# NOTE: pyright for type
|
||||
"pyright",
|
||||
# NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy
|
||||
"coverage[toml]>=6.5",
|
||||
"filelock>=3.7.1",
|
||||
@@ -26,25 +24,28 @@ dependencies = [
|
||||
]
|
||||
features = ['flan-t5']
|
||||
[envs.default.scripts]
|
||||
_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml"
|
||||
changelog = "towncrier build --version main --draft"
|
||||
fmt = ["tools", "pre-commit run --all-files"]
|
||||
full = "_run_script --reruns 5 --reruns-delay 3 -r aR {args:tests}"
|
||||
setup = "pre-commit install"
|
||||
tools = [
|
||||
quality = [
|
||||
"./tools/update-readme.py",
|
||||
"./tools/update-optional-dependencies.py",
|
||||
"./tools/update-config-stubs.py",
|
||||
"./tools/update-models-import.py",
|
||||
"- ./tools/add-license-headers .",
|
||||
"pre-commit run --all-files",
|
||||
]
|
||||
typing = "pyright {args:src/openllm tests}"
|
||||
[envs.test.overrides]
|
||||
setup = "pre-commit install"
|
||||
typing = "pre-commit run typecheck --all-files"
|
||||
[envs.tests]
|
||||
extra-dependencies = [
|
||||
# NOTE: interact with docker for container tests.
|
||||
"docker",
|
||||
]
|
||||
[envs.tests.scripts]
|
||||
_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml"
|
||||
models = "_run_script -r aR {args:tests/models}"
|
||||
python = "_run_script --reruns 5 --reruns-delay 3 --ignore tests/models -n 3 -r aR {args:tests}"
|
||||
[envs.tests.overrides]
|
||||
env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT="
|
||||
env.HERMETIC_TESTS.type = [{ value = "container", if = ["true"] }, "virtual"]
|
||||
[envs.test.scripts]
|
||||
[[envs.test.matrix]]
|
||||
python = ["3.8", "3.9", "3.10", "3.11"]
|
||||
[envs.coverage]
|
||||
dependencies = ["coverage[toml]>=6.5", "lxml", "orjson"]
|
||||
detached = true
|
||||
|
||||
131
pyproject.toml
131
pyproject.toml
@@ -72,12 +72,12 @@ all = [
|
||||
"openllm[falcon]",
|
||||
"openllm[mpt]",
|
||||
"openllm[starcoder]",
|
||||
"openllm[ggml]",
|
||||
"openllm[playground]",
|
||||
"openllm[fine-tune]",
|
||||
"openllm[agents]",
|
||||
"openllm[flan-t5]",
|
||||
"openllm[openai]",
|
||||
"openllm[playground]",
|
||||
"openllm[flan-t5]",
|
||||
"openllm[agents]",
|
||||
"openllm[ggml]",
|
||||
]
|
||||
chatglm = ["cpm-kernels", "sentencepiece"]
|
||||
falcon = ["einops", "xformers", "safetensors"]
|
||||
@@ -155,7 +155,7 @@ verbose = 2
|
||||
whitelist-regex = ["test_.*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = ["-rfEX", "-pno:warnings"]
|
||||
addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
testpaths = ["tests"]
|
||||
|
||||
@@ -183,70 +183,125 @@ line-length = 119
|
||||
target-version = ["py311"]
|
||||
|
||||
[tool.ruff]
|
||||
exclude = ["tools"]
|
||||
exclude = ["tools", "src/openllm/playground"]
|
||||
extend-select = [
|
||||
"B", # flake8-bugbear
|
||||
"I", # isort
|
||||
"G", # flake8-logging-format
|
||||
"D", # pydocstyle
|
||||
"W", # pycodestyle
|
||||
"Q", # flake8-quotes
|
||||
"FA", # flake8-future-annotations
|
||||
"S", # flake8-bandit
|
||||
"TCH", # flake8-type-checking
|
||||
"PLW", # pylint-warning
|
||||
"PLR", # pylint-refactor
|
||||
"PT", # flake8-pytest-style
|
||||
"PYI", # flake8-pyi
|
||||
"PERF", # perflint
|
||||
"FLY", # flynt
|
||||
"RUF", # Ruff-specific rules
|
||||
"YTT", # flake8-2020
|
||||
]
|
||||
fix = true
|
||||
ignore = [
|
||||
# Allow non-abstract empty methods in abstract base classes
|
||||
"B027",
|
||||
# Allow boolean positional values in function calls, like `dict.get(... True)`
|
||||
"FBT003",
|
||||
# Ignore checks for possible passwords
|
||||
"S105",
|
||||
"B027", # Allow non-abstract empty methods in abstract base classes
|
||||
"FBT003", # Allow boolean positional values in function calls, like `dict.get(... True)`
|
||||
"S105", # Ignore checks for possible passwords
|
||||
"S106",
|
||||
"S107",
|
||||
# Ignore complexity
|
||||
"C901",
|
||||
"S603", # ignore subprocess.call
|
||||
"PLR0911",
|
||||
"PLR0912",
|
||||
"PLR0913",
|
||||
"PLR0915",
|
||||
"E501",
|
||||
"E741",
|
||||
"PLR2004", # magic value to use constant
|
||||
"E501", # ignore line length violation
|
||||
"PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs.
|
||||
"D103", # Just missing docstring for magic methods.
|
||||
"D102",
|
||||
"D101",
|
||||
"D100",
|
||||
"TCH004", # don't move runtime import out, just warn about it
|
||||
"RUF012", # mutable attributes to be used with ClassVar
|
||||
"B905", # zip warning about strict, only applicable for 3.10+
|
||||
]
|
||||
line-length = 119
|
||||
target-version = "py311"
|
||||
target-version = "py312"
|
||||
unfixable = [
|
||||
"F401", # Don't touch unused imports, just warn about it.
|
||||
"F401", # Don't touch unused imports, just warn about it.
|
||||
"TCH004", # Don't touch import outside of TYPE_CHECKING block
|
||||
]
|
||||
|
||||
[tool.ruff.flake8-type-checking]
|
||||
exempt-modules = ["typing", "typing_extensions", "."]
|
||||
runtime-evaluated-base-classes = [
|
||||
"pydantic.BaseModel",
|
||||
"openllm._configuration.LLMConfig",
|
||||
"openllm._configuration.GenerationConfig",
|
||||
"openllm._configuration.ModelSettings",
|
||||
]
|
||||
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen"]
|
||||
|
||||
[tool.ruff.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[tool.ruff.pycodestyle]
|
||||
ignore-overlong-task-comments = true
|
||||
|
||||
[tool.ruff.isort]
|
||||
force-single-line = true
|
||||
known-first-party = ["openllm", "bentoml", 'transformers']
|
||||
lines-after-imports = 2
|
||||
no-lines-before = ["future", "standard-library"]
|
||||
relative-imports-order = "closest-to-furthest"
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
[tool.ruff.flake8-quotes]
|
||||
avoid-escape = false
|
||||
|
||||
[tool.ruff.extend-per-file-ignores]
|
||||
# Tests can use magic values, assertions, and relative imports
|
||||
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
||||
"examples/**/*" = ["D"]
|
||||
"src/openllm/_llm.py" = ["B010", "B009"]
|
||||
"src/openllm/_strategies.py" = ["B904"]
|
||||
"src/openllm/_types.py" = ["E402"]
|
||||
"src/openllm/playground/**/*" = ["E402", "F401"]
|
||||
"tests/**/*" = ["PLR2004", "S101", "TID252"]
|
||||
"src/openllm/cli.py" = ["D301", "S101"]
|
||||
"src/openllm/models/**/*" = ["D106", "S101", "D104"]
|
||||
"src/openllm/playground/**/*" = ["E402", "F401", "PLR", "D"]
|
||||
"src/openllm/utils/dummy_*" = ["D107"]
|
||||
"src/openllm/utils/import_utils.py" = [
|
||||
"PLW0603", # OK to ignore global access here
|
||||
"D105", # magic docstring
|
||||
]
|
||||
"src/openllm_client/runtimes/*" = ["D107"]
|
||||
"tests/**/*" = [
|
||||
"S101",
|
||||
"TID252",
|
||||
"D", # No docstring in tests
|
||||
"PT011", # ignore too broad raises, as it can be use pytest.raises().match()
|
||||
"S307", # Ignore eval(compile) as it is a known script execution
|
||||
]
|
||||
"typings/**/*" = ["D", "F", "E", "PYI002", "I001"]
|
||||
|
||||
[tool.pyright]
|
||||
analysis.useLibraryCodeForTypes = true
|
||||
enableTypeIgnoreComments = true
|
||||
include = ["src/", "tests/", "tools/", "examples/"]
|
||||
exclude = ["src/openllm/playground", "src/openllm/models/"]
|
||||
include = ["src/openllm", "src/openllm_client", "tests/", "tools/", "examples/"]
|
||||
pythonVersion = "3.12"
|
||||
reportMissingImports = "none"
|
||||
reportMissingModuleSource = "warning"
|
||||
reportMissingTypeStubs = "warning"
|
||||
reportMissingTypeStubs = false
|
||||
reportPrivateUsage = "warning"
|
||||
reportUnknownArgumentType = "warning"
|
||||
reportUnknownMemberType = "warning"
|
||||
reportUnknownVariableType = "warning"
|
||||
strictDictionaryInference = true
|
||||
strictListInference = true
|
||||
strictParameterNoneValue = true
|
||||
strictSetInference = true
|
||||
typeCheckingMode = "strict"
|
||||
|
||||
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
omit = [
|
||||
"src/openllm/playground/",
|
||||
"src/openllm/__about__.py",
|
||||
"src/openllm/__main__.py",
|
||||
"src/openllm/tests.py",
|
||||
"src/openllm/utils/dummy_*.py",
|
||||
]
|
||||
source_pkgs = ["openllm"]
|
||||
@@ -255,4 +310,14 @@ source_pkgs = ["openllm"]
|
||||
openllm = ["src/openllm", "*/openllm/src/openllm"]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:", "@overload", "# pragma: no cover"]
|
||||
exclude_lines = [
|
||||
"no cov",
|
||||
"pragma: no cover",
|
||||
"if __name__ == .__main__.:",
|
||||
"if t.TYPE_CHECKING:",
|
||||
'if TYPE_CHECKING:',
|
||||
'if typing.TYPE_CHECKING:',
|
||||
'@overload',
|
||||
'@typing.overload',
|
||||
'raise NotImplementedError',
|
||||
]
|
||||
|
||||
@@ -11,9 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
OpenLLM
|
||||
=======
|
||||
"""OpenLLM.
|
||||
|
||||
An open platform for operating large language models in production. Fine-tune, serve,
|
||||
deploy, and monitor any LLMs with ease.
|
||||
@@ -24,7 +22,6 @@ deploy, and monitor any LLMs with ease.
|
||||
* Native integration with BentoML and LangChain for custom LLM apps
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
@@ -39,7 +36,6 @@ if utils.DEBUG:
|
||||
utils.set_debug_mode(True)
|
||||
utils.set_quiet_mode(False)
|
||||
|
||||
utils.configure_logging()
|
||||
logging.basicConfig(level=logging.NOTSET)
|
||||
else:
|
||||
# configuration for bitsandbytes before import
|
||||
@@ -64,16 +60,15 @@ else:
|
||||
_import_structure = {
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
|
||||
"_configuration": ["LLMConfig"],
|
||||
"_package": ["build"],
|
||||
"exceptions": [],
|
||||
"_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput"],
|
||||
"utils": [],
|
||||
"utils": ["infer_auto_class"],
|
||||
"models": [],
|
||||
"client": [],
|
||||
"playground": [],
|
||||
"tests": [],
|
||||
"testing": [],
|
||||
"serialisation": ["ggml", "transformers"],
|
||||
"cli": ["start", "start_grpc"],
|
||||
"cli": ["start", "start_grpc", "build", "import_model", "list_models"],
|
||||
# NOTE: models
|
||||
"models.auto": [
|
||||
"AutoConfig",
|
||||
@@ -182,23 +177,23 @@ if t.TYPE_CHECKING:
|
||||
from . import exceptions as exceptions
|
||||
from . import models as models
|
||||
from . import playground as playground
|
||||
from . import tests as tests
|
||||
from . import serialisation as serialisation
|
||||
from . import testing as testing
|
||||
|
||||
# Specific types import
|
||||
from ._configuration import LLMConfig as LLMConfig
|
||||
from ._llm import LLM as LLM
|
||||
from ._llm import LLMRunner as LLMRunner
|
||||
from ._llm import LLMRunnable as LLMRunnable
|
||||
from ._llm import LLMRunner as LLMRunner
|
||||
from ._llm import Runner as Runner
|
||||
from ._package import build as build
|
||||
from ._schema import GenerationInput as GenerationInput
|
||||
from ._schema import GenerationOutput as GenerationOutput
|
||||
from ._schema import MetadataOutput as MetadataOutput
|
||||
from .cli import build as build
|
||||
from .cli import import_model as import_model
|
||||
from .cli import list_models as list_models
|
||||
from .cli import start as start
|
||||
from .cli import start_grpc as start_grpc
|
||||
from .serialisation import ggml as ggml
|
||||
from .serialisation import transformers as transformers
|
||||
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES
|
||||
from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
@@ -213,6 +208,9 @@ if t.TYPE_CHECKING:
|
||||
from .models.opt import OPTConfig as OPTConfig
|
||||
from .models.stablelm import StableLMConfig as StableLMConfig
|
||||
from .models.starcoder import StarCoderConfig as StarCoderConfig
|
||||
from .serialisation import ggml as ggml
|
||||
from .serialisation import transformers as transformers
|
||||
from .utils import infer_auto_class as infer_auto_class
|
||||
|
||||
# NOTE: torch and cpm_kernels
|
||||
try:
|
||||
@@ -286,6 +284,7 @@ else:
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
doc=__doc__,
|
||||
extra_objects={
|
||||
"__version__": __version__,
|
||||
# The below is a special mapping that allows openllm to be used as a dictionary.
|
||||
|
||||
@@ -11,8 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
CLI entrypoint for OpenLLM.
|
||||
"""CLI entrypoint for OpenLLM.
|
||||
|
||||
Usage:
|
||||
openllm --help
|
||||
|
||||
@@ -11,8 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``.
|
||||
"""Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``.
|
||||
|
||||
Highlight feature: Each fields in ``openllm.LLMConfig`` will also automatically generate a environment
|
||||
variable based on its name field.
|
||||
@@ -47,7 +46,6 @@ dynamically during serve, ahead-of-serve or per requests.
|
||||
Refer to ``openllm.LLMConfig`` docstring for more information.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import enum
|
||||
import logging
|
||||
@@ -68,56 +66,48 @@ from deepmerge.merger import Merger
|
||||
import openllm
|
||||
|
||||
from .exceptions import ForbiddenAttributeError
|
||||
from .utils import ENV_VARS_TRUE_VALUES
|
||||
from .utils import LazyType
|
||||
from .utils import ReprMixin, ENV_VARS_TRUE_VALUES
|
||||
from .utils import ReprMixin
|
||||
from .utils import bentoml_cattr
|
||||
from .utils import codegen
|
||||
from .utils import dantic
|
||||
from .utils import first_not_none, field_env_key
|
||||
from .utils import field_env_key
|
||||
from .utils import first_not_none
|
||||
from .utils import lenient_issubclass
|
||||
from .utils import non_intrusive_setattr
|
||||
from .utils import requires_dependencies
|
||||
|
||||
|
||||
if hasattr(t, "Required"):
|
||||
from typing import Required
|
||||
else:
|
||||
from typing_extensions import Required
|
||||
|
||||
if hasattr(t, "NotRequired"):
|
||||
from typing import NotRequired
|
||||
else:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
if hasattr(t, "dataclass_transform"):
|
||||
from typing import dataclass_transform
|
||||
else:
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# NOTE: We need to do check overload import
|
||||
# so that it can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import NotRequired
|
||||
from typing import Required
|
||||
from typing import dataclass_transform
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import NotRequired
|
||||
from typing_extensions import Required
|
||||
from typing_extensions import dataclass_transform
|
||||
from typing_extensions import overload
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import click
|
||||
import peft
|
||||
from attr import _CountingAttr # type: ignore
|
||||
from attr import _make_init # type: ignore
|
||||
from attr import _transform_attrs # type: ignore
|
||||
from attr import _CountingAttr
|
||||
from attr import _make_init
|
||||
from attr import _transform_attrs
|
||||
from attr._compat import set_closure_cell
|
||||
|
||||
import transformers
|
||||
from transformers.generation.beam_constraints import Constraint
|
||||
|
||||
from ._types import ClickFunctionWrapper
|
||||
from ._types import F
|
||||
from ._types import O_co
|
||||
from ._types import P
|
||||
from ._types import AnyCallable
|
||||
|
||||
DictStrAny = dict[str, t.Any]
|
||||
ListStr = list[str]
|
||||
@@ -154,10 +144,10 @@ config_merger = Merger(
|
||||
|
||||
# case insensitive, but rename to conform with type
|
||||
class _PeftEnumMeta(enum.EnumMeta):
|
||||
def __getitem__(self, __key: str | t.Any) -> PeftType:
|
||||
def __getitem__(self, __key: str | t.Any) -> enum.Enum:
|
||||
if isinstance(__key, str):
|
||||
__key = inflection.underscore(__key).upper()
|
||||
return super().__getitem__(__key)
|
||||
return self._member_map_[__key]
|
||||
|
||||
|
||||
# vendorred from peft.utils.config.PeftType
|
||||
@@ -171,11 +161,11 @@ class PeftType(enum.Enum, metaclass=_PeftEnumMeta):
|
||||
ADAPTION_PROMPT = "ADAPTION_PROMPT"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value: object) -> PeftType | None:
|
||||
def _missing_(cls, value: object) -> enum.Enum | None:
|
||||
if isinstance(value, str):
|
||||
normalized = inflection.underscore(value).upper()
|
||||
if normalized in cls._member_map_:
|
||||
return cls[normalized]
|
||||
return cls._member_map_[normalized]
|
||||
|
||||
@classmethod
|
||||
def supported(cls) -> set[str]:
|
||||
@@ -184,6 +174,11 @@ class PeftType(enum.Enum, metaclass=_PeftEnumMeta):
|
||||
def to_str(self) -> str:
|
||||
return self.value
|
||||
|
||||
@staticmethod
|
||||
def get(__key: str | t.Any) -> PeftType:
|
||||
"""type-safe getitem."""
|
||||
return t.cast(PeftType, PeftType[__key])
|
||||
|
||||
|
||||
_PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"}
|
||||
|
||||
@@ -200,14 +195,33 @@ def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType:
|
||||
raise ValueError("'AdapterType' cannot be None.")
|
||||
if isinstance(value, PeftType):
|
||||
return value
|
||||
if isinstance(value, str) and value not in PeftType.supported():
|
||||
if value not in PeftType.supported():
|
||||
raise ValueError(f"Given '{value}' is not a supported adapter type.")
|
||||
return PeftType[value]
|
||||
return PeftType.get(value)
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class FineTuneConfig:
|
||||
"""FineTuneConfig defines a default value for fine-tuning this any given LLM. For example:
|
||||
"""FineTuneConfig defines a default value for fine-tuning this any given LLM.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
class FalconConfig(openllm.LLMConfig):
|
||||
|
||||
__config__ = {
|
||||
"fine_tune_strategies": (
|
||||
{
|
||||
"adapter_type": "lora",
|
||||
"r": 64,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.1,
|
||||
"bias": "none",
|
||||
"target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
|
||||
},
|
||||
),
|
||||
}
|
||||
```
|
||||
|
||||
This is a lower level API that leverage `peft` as well as openllm.LLMConfig to create default
|
||||
and customization
|
||||
@@ -318,8 +332,7 @@ class FineTuneConfig:
|
||||
docs: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> type[FineTuneConfig]:
|
||||
"""A loose codegen to create default subclass for given adapter config type"""
|
||||
|
||||
"""A loose codegen to create default subclass for given adapter config type."""
|
||||
_new_default = {
|
||||
"adapter_type": PeftType[adapter_type],
|
||||
"adapter_config": attrs,
|
||||
@@ -355,8 +368,7 @@ class FineTuneConfig:
|
||||
|
||||
@attr.frozen(slots=True, repr=False)
|
||||
class GenerationConfig(ReprMixin):
|
||||
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
|
||||
with some additional validation and environment constructor.
|
||||
"""GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and environment constructor.
|
||||
|
||||
Note that we always set `do_sample=True`. This class is not designed to be used directly, rather
|
||||
to be used conjunction with LLMConfig. The instance of the generation config can then be accessed
|
||||
@@ -588,7 +600,7 @@ class GenerationConfig(ReprMixin):
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
def __attrs_init__(self, **_: t.Any):
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
|
||||
...
|
||||
|
||||
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
|
||||
@@ -628,6 +640,7 @@ _object_getattribute = object.__getattribute__
|
||||
|
||||
class ModelSettings(t.TypedDict, total=False):
|
||||
"""ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__.
|
||||
|
||||
Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig.
|
||||
|
||||
If the field below changes, make sure to run ./tools/update-config-stubs.py to generate correct __getitem__
|
||||
@@ -728,9 +741,9 @@ class _ModelSettingsAttr:
|
||||
service_name: str
|
||||
requirements: t.Optional[ListStr]
|
||||
bettertransformer: bool
|
||||
model_type: t.Literal['causal_lm', 'seq2seq_lm']
|
||||
runtime: t.Literal['transformers', 'ggml']
|
||||
name_type: t.Optional[t.Literal['dasherize', 'lowercase']]
|
||||
model_type: t.Literal["causal_lm", "seq2seq_lm"]
|
||||
runtime: t.Literal["transformers", "ggml"]
|
||||
name_type: t.Optional[t.Literal["dasherize", "lowercase"]]
|
||||
model_name: str
|
||||
start_name: str
|
||||
env: openllm.utils.EnvVarMixin
|
||||
@@ -743,7 +756,6 @@ class _ModelSettingsAttr:
|
||||
|
||||
|
||||
def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]):
|
||||
assert cl_.__config__ is not None, f"'__config__' is required for {cls}."
|
||||
if "generation_class" in cl_.__config__:
|
||||
raise ValueError(
|
||||
"'generation_class' shouldn't be defined in '__config__', rather defining "
|
||||
@@ -813,8 +825,8 @@ bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings)
|
||||
|
||||
|
||||
def _setattr_class(attr_name: str, value_var: t.Any):
|
||||
"""
|
||||
Use the builtin setattr to set *attr_name* to *value_var*.
|
||||
"""Use the builtin setattr to set *attr_name* to *value_var*.
|
||||
|
||||
We can't use the cached object.__setattr__ since we are setting
|
||||
attributes to a class.
|
||||
|
||||
@@ -828,7 +840,7 @@ def _setattr_class(attr_name: str, value_var: t.Any):
|
||||
def _make_assignment_script(
|
||||
cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: t.LiteralString = "openllm"
|
||||
) -> t.Callable[..., None]:
|
||||
"""Generate the assignment script with prefix attributes __openllm_<value>__"""
|
||||
"""Generate the assignment script with prefix attributes __openllm_<value>__."""
|
||||
args: ListStr = []
|
||||
globs: DictStrAny = {
|
||||
"cls": cls,
|
||||
@@ -852,13 +864,14 @@ def _make_assignment_script(
|
||||
_reserved_namespace = {"__config__", "GenerationConfig"}
|
||||
|
||||
|
||||
@dataclass_transform(order_default=True, field_specifiers=(attr.field, dantic.Field))
|
||||
@dataclass_transform(kw_only_default=True, order_default=True, field_specifiers=(attr.field, dantic.Field))
|
||||
def llm_config_transform(cls: type[LLMConfig]) -> type[LLMConfig]:
|
||||
non_intrusive_setattr(
|
||||
cls,
|
||||
"__dataclass_transform__",
|
||||
{
|
||||
"order_default": True,
|
||||
"kw_only_default": True,
|
||||
"field_specifiers": (attr.field, dantic.Field),
|
||||
},
|
||||
)
|
||||
@@ -880,7 +893,7 @@ class _ConfigAttr:
|
||||
# NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING
|
||||
if t.TYPE_CHECKING:
|
||||
# NOTE: public attributes to override
|
||||
__config__: ModelSettings | None = Field(None)
|
||||
__config__: ModelSettings = Field(None)
|
||||
"""Internal configuration for this LLM model. Each of the field in here will be populated
|
||||
and prefixed with __openllm_<value>__"""
|
||||
GenerationConfig: type = Field(None)
|
||||
@@ -914,7 +927,7 @@ class _ConfigAttr:
|
||||
to create the generation_config argument that can be used throughout the lifecycle.
|
||||
This class will also be managed internally by OpenLLM."""
|
||||
|
||||
def __attrs_init__(self, **attrs: t.Any):
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
|
||||
"""Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract."""
|
||||
|
||||
# NOTE: The following will be populated from __config__ and also
|
||||
@@ -955,14 +968,14 @@ class _ConfigAttr:
|
||||
architecture. By default, we will use BetterTransformer for T5 and StableLM models,
|
||||
and set to False for every other models.
|
||||
"""
|
||||
__openllm_model_type__: t.Literal['causal_lm', 'seq2seq_lm'] = Field(None)
|
||||
__openllm_model_type__: t.Literal["causal_lm", "seq2seq_lm"] = Field(None)
|
||||
"""The model type for this given LLM. By default, it should be causal language modeling.
|
||||
Currently supported 'causal_lm' or 'seq2seq_lm'
|
||||
"""
|
||||
__openllm_runtime__: t.Literal['transformers', 'ggml'] = Field(None)
|
||||
__openllm_runtime__: t.Literal["transformers", "ggml"] = Field(None)
|
||||
"""The runtime to use for this model. Possible values are `transformers` or `ggml`. See
|
||||
LlaMA for more information."""
|
||||
__openllm_name_type__: t.Optional[t.Literal['dasherize', 'lowercase']] = Field(None)
|
||||
__openllm_name_type__: t.Optional[t.Literal["dasherize", "lowercase"]] = Field(None)
|
||||
"""The default name typed for this model. "dasherize" will convert the name to lowercase and
|
||||
replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both
|
||||
`model_name` and `start_name` must be specified."""
|
||||
@@ -991,11 +1004,187 @@ class _ConfigAttr:
|
||||
# fmt: on
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class LLMConfig(_ConfigAttr):
|
||||
class _ConfigBuilder:
|
||||
"""A modified version of attrs internal _ClassBuilder, and should only be called within __init_subclass__ of LLMConfig.
|
||||
|
||||
Where:
|
||||
- has_custom_setattr=True
|
||||
- getstate_setstate=None (config class will always be a slotted class.)
|
||||
- slots=True
|
||||
- auto_attribs=False (We should handle it before _ConfigBuilder is invoked)
|
||||
- cache_hash=False (We don't need to cache the hash code of this object for now.)
|
||||
- collect_by_mro=True (The correct behaviour to resolve inheritance)
|
||||
- field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable)
|
||||
|
||||
It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__
|
||||
"""
|
||||
``openllm.LLMConfig`` is somewhat a hybrid combination between the performance of `attrs` with the
|
||||
easy-to-use interface that pydantic offer. It lives in between where it allows users to quickly formulate
|
||||
|
||||
__slots__ = (
|
||||
"_cls",
|
||||
"_cls_dict",
|
||||
"_attr_names",
|
||||
"_attrs",
|
||||
"_model_name",
|
||||
"_base_attr_map",
|
||||
"_base_names",
|
||||
"_has_pre_init",
|
||||
"_has_post_init",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cls: type[LLMConfig],
|
||||
these: dict[str, _CountingAttr[t.Any]],
|
||||
auto_attribs: bool = False,
|
||||
kw_only: bool = False,
|
||||
collect_by_mro: bool = True,
|
||||
):
|
||||
attrs, base_attrs, base_attr_map = _transform_attrs(
|
||||
cls,
|
||||
these,
|
||||
auto_attribs,
|
||||
kw_only,
|
||||
collect_by_mro,
|
||||
field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__),
|
||||
)
|
||||
|
||||
self._cls = cls
|
||||
self._model_name = cls.__openllm_model_name__
|
||||
self._cls_dict = dict(cls.__dict__)
|
||||
self._attrs = attrs
|
||||
self._base_names = {a.name for a in base_attrs}
|
||||
self._base_attr_map = base_attr_map
|
||||
self._attr_names = tuple(a.name for a in attrs)
|
||||
self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
|
||||
self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
|
||||
|
||||
self._cls_dict["__attrs_attrs__"] = self._attrs
|
||||
|
||||
def build_class(self) -> type[LLMConfig]:
|
||||
"""Finalize class based on the accumulated configuration.
|
||||
|
||||
Builder cannot be used after calling this method.
|
||||
|
||||
> A difference between this and attrs._ClassBuilder is that we don't
|
||||
> create a new class after constructing all __dict__. This has to do
|
||||
> with recursive called within __init_subclass__
|
||||
"""
|
||||
cd = {
|
||||
k: v for k, v in self._cls_dict.items() if k not in (*tuple(self._attr_names), "__dict__", "__weakref__")
|
||||
}
|
||||
# Traverse the MRO to collect existing slots
|
||||
# and check for an existing __weakref__.
|
||||
existing_slots: DictStrAny = {}
|
||||
weakref_inherited = False
|
||||
for base_cls in self._cls.__mro__[1:-1]:
|
||||
if base_cls.__dict__.get("__weakref__", None) is not None:
|
||||
weakref_inherited = True
|
||||
existing_slots.update(
|
||||
{name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])}
|
||||
)
|
||||
|
||||
base_names = set(self._base_names)
|
||||
names = self._attr_names
|
||||
if (
|
||||
"__weakref__" not in getattr(self._cls, "__slots__", ())
|
||||
and "__weakref__" not in names
|
||||
and not weakref_inherited
|
||||
):
|
||||
names += ("__weakref__",)
|
||||
|
||||
# We only add the names of attributes that aren't inherited.
|
||||
# Setting __slots__ to inherited attributes wastes memory.
|
||||
slot_names = [name for name in names if name not in base_names]
|
||||
# There are slots for attributes from current class
|
||||
# that are defined in parent classes.
|
||||
# As their descriptors may be overridden by a child class,
|
||||
# we collect them here and update the class dict
|
||||
reused_slots = {
|
||||
slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names
|
||||
}
|
||||
# We only add the names of attributes that aren't inherited.
|
||||
# Setting __slots__ to inherited attributes wastes memory.
|
||||
# __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots
|
||||
slot_names = [name for name in slot_names if name not in reused_slots]
|
||||
cd.update(reused_slots)
|
||||
cd["__slots__"] = tuple(slot_names)
|
||||
|
||||
cd["__qualname__"] = self._cls.__qualname__
|
||||
|
||||
# We can only patch the class here, rather than instantiate
|
||||
# a new one, since type.__new__ actually will invoke __init_subclass__
|
||||
# and since we use the _ConfigBuilder in __init_subclass__, it will
|
||||
# raise recusion error. See https://peps.python.org/pep-0487/ for more
|
||||
# information on how __init_subclass__ works.
|
||||
for k, value in cd.items():
|
||||
setattr(self._cls, k, value)
|
||||
|
||||
return self.make_closure(self._cls)
|
||||
|
||||
def make_closure(self, cls: type):
|
||||
# The following is a fix for
|
||||
# <https://github.com/python-attrs/attrs/issues/102>.
|
||||
# If a method mentions `__class__` or uses the no-arg super(), the
|
||||
# compiler will bake a reference to the class in the method itself
|
||||
# as `method.__closure__`. Since we replace the class with a
|
||||
# clone, we rewrite these references so it keeps working.
|
||||
for item in cls.__dict__.values():
|
||||
if isinstance(item, (classmethod, staticmethod)):
|
||||
# Class- and staticmethods hide their functions inside.
|
||||
# These might need to be rewritten as well.
|
||||
closure_cells = getattr(item.__func__, "__closure__", None)
|
||||
elif isinstance(item, property):
|
||||
# Workaround for property `super()` shortcut (PY3-only).
|
||||
# There is no universal way for other descriptors.
|
||||
closure_cells = getattr(item.fget, "__closure__", None)
|
||||
else:
|
||||
closure_cells = getattr(item, "__closure__", None)
|
||||
|
||||
if not closure_cells: # Catch None or the empty list.
|
||||
continue
|
||||
for cell in closure_cells:
|
||||
try:
|
||||
match = cell.cell_contents is self._cls
|
||||
except ValueError: # ValueError: Cell is empty
|
||||
pass
|
||||
else:
|
||||
if match:
|
||||
set_closure_cell(cell, cls)
|
||||
|
||||
return llm_config_transform(cls)
|
||||
|
||||
def add_attrs_init(self) -> t.Self:
|
||||
self._cls_dict["__attrs_init__"] = codegen.add_method_dunders(
|
||||
self._cls,
|
||||
_make_init(
|
||||
self._cls,
|
||||
self._attrs,
|
||||
self._has_pre_init,
|
||||
self._has_post_init,
|
||||
False, # frozen
|
||||
True, # slots
|
||||
False, # cache_hash
|
||||
self._base_attr_map,
|
||||
False, # This is not an exception
|
||||
None, # no on_setattr
|
||||
True,
|
||||
),
|
||||
)
|
||||
return self
|
||||
|
||||
def add_repr(self):
|
||||
for key, fn in ReprMixin.__dict__.items():
|
||||
if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"):
|
||||
self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
|
||||
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"})
|
||||
return self
|
||||
|
||||
|
||||
@attr.define(slots=True, init=False)
|
||||
class LLMConfig(_ConfigAttr):
|
||||
"""``openllm.LLMConfig`` is a pydantic-like ``attrs`` interface that offers fast and easy-to-use APIs.
|
||||
|
||||
It lives in between the nice UX of `pydantic` and fast performance of `attrs` where it allows users to quickly formulate
|
||||
a LLMConfig for any LLM without worrying too much about performance. It does a few things:
|
||||
|
||||
- Automatic environment conversion: Each fields will automatically be provisioned with an environment
|
||||
@@ -1077,182 +1266,16 @@ class LLMConfig(_ConfigAttr):
|
||||
),
|
||||
}
|
||||
```
|
||||
|
||||
Future work:
|
||||
- Support pydantic-core as validation backend.
|
||||
"""
|
||||
|
||||
class _ConfigBuilder:
|
||||
"""A modified version of attrs internal _ClassBuilder, should only be called
|
||||
within __init_subclass__ of LLMConfig.
|
||||
|
||||
Where:
|
||||
- has_custom_setattr=True
|
||||
- getstate_setstate=None (config class will always be a slotted class.)
|
||||
- slots=True
|
||||
- auto_attribs=False (We should handle it before _ConfigBuilder is invoked)
|
||||
- cache_hash=False (We don't need to cache the hash code of this object for now.)
|
||||
- collect_by_mro=True (The correct behaviour to resolve inheritance)
|
||||
- field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable)
|
||||
|
||||
It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_cls",
|
||||
"_cls_dict",
|
||||
"_attr_names",
|
||||
"_attrs",
|
||||
"_model_name",
|
||||
"_base_attr_map",
|
||||
"_base_names",
|
||||
"_has_pre_init",
|
||||
"_has_post_init",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cls: type[LLMConfig],
|
||||
these: dict[str, _CountingAttr[t.Any]],
|
||||
auto_attribs: bool = False,
|
||||
kw_only: bool = False,
|
||||
collect_by_mro: bool = True,
|
||||
):
|
||||
attrs, base_attrs, base_attr_map = _transform_attrs(
|
||||
cls,
|
||||
these,
|
||||
auto_attribs,
|
||||
kw_only,
|
||||
collect_by_mro,
|
||||
field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__),
|
||||
)
|
||||
|
||||
self._cls = cls
|
||||
self._model_name = cls.__openllm_model_name__
|
||||
self._cls_dict = dict(cls.__dict__)
|
||||
self._attrs = attrs
|
||||
self._base_names = {a.name for a in base_attrs}
|
||||
self._base_attr_map = base_attr_map
|
||||
self._attr_names = tuple(a.name for a in attrs)
|
||||
self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False))
|
||||
self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False))
|
||||
|
||||
self._cls_dict["__attrs_attrs__"] = self._attrs
|
||||
|
||||
def build_class(self) -> type[LLMConfig]:
|
||||
"""
|
||||
Finalize class based on the accumulated configuration.
|
||||
|
||||
Builder cannot be used after calling this method.
|
||||
|
||||
> A difference between this and attrs._ClassBuilder is that we don't
|
||||
> create a new class after constructing all __dict__. This has to do
|
||||
> with recursive called within __init_subclass__
|
||||
"""
|
||||
cd = {
|
||||
k: v
|
||||
for k, v in self._cls_dict.items()
|
||||
if k not in tuple(self._attr_names) + ("__dict__", "__weakref__")
|
||||
}
|
||||
# Traverse the MRO to collect existing slots
|
||||
# and check for an existing __weakref__.
|
||||
existing_slots: DictStrAny = {}
|
||||
weakref_inherited = False
|
||||
for base_cls in self._cls.__mro__[1:-1]:
|
||||
if base_cls.__dict__.get("__weakref__", None) is not None:
|
||||
weakref_inherited = True
|
||||
existing_slots.update(
|
||||
{name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])}
|
||||
)
|
||||
|
||||
base_names = set(self._base_names)
|
||||
names = self._attr_names
|
||||
if (
|
||||
"__weakref__" not in getattr(self._cls, "__slots__", ())
|
||||
and "__weakref__" not in names
|
||||
and not weakref_inherited
|
||||
):
|
||||
names += ("__weakref__",)
|
||||
|
||||
# We only add the names of attributes that aren't inherited.
|
||||
# Setting __slots__ to inherited attributes wastes memory.
|
||||
slot_names = [name for name in names if name not in base_names]
|
||||
# There are slots for attributes from current class
|
||||
# that are defined in parent classes.
|
||||
# As their descriptors may be overridden by a child class,
|
||||
# we collect them here and update the class dict
|
||||
reused_slots = {
|
||||
slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names
|
||||
}
|
||||
# We only add the names of attributes that aren't inherited.
|
||||
# Setting __slots__ to inherited attributes wastes memory.
|
||||
# __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots
|
||||
slot_names = [name for name in slot_names if name not in reused_slots]
|
||||
cd.update(reused_slots)
|
||||
cd["__slots__"] = tuple(slot_names)
|
||||
|
||||
for k, value in cd.items():
|
||||
setattr(self._cls, k, value)
|
||||
|
||||
# The following is a fix for
|
||||
# <https://github.com/python-attrs/attrs/issues/102>.
|
||||
# If a method mentions `__class__` or uses the no-arg super(), the
|
||||
# compiler will bake a reference to the class in the method itself
|
||||
# as `method.__closure__`. Since we replace the class with a
|
||||
# clone, we rewrite these references so it keeps working.
|
||||
for item in self._cls.__dict__.values():
|
||||
if isinstance(item, (classmethod, staticmethod)):
|
||||
# Class- and staticmethods hide their functions inside.
|
||||
# These might need to be rewritten as well.
|
||||
closure_cells = getattr(item.__func__, "__closure__", None)
|
||||
elif isinstance(item, property):
|
||||
# Workaround for property `super()` shortcut (PY3-only).
|
||||
# There is no universal way for other descriptors.
|
||||
closure_cells = getattr(item.fget, "__closure__", None)
|
||||
else:
|
||||
closure_cells = getattr(item, "__closure__", None)
|
||||
|
||||
if not closure_cells: # Catch None or the empty list.
|
||||
continue
|
||||
for cell in closure_cells:
|
||||
try:
|
||||
match = cell.cell_contents is self._cls
|
||||
except ValueError: # ValueError: Cell is empty
|
||||
pass
|
||||
else:
|
||||
if match:
|
||||
set_closure_cell(cell, self._cls)
|
||||
|
||||
return llm_config_transform(self._cls)
|
||||
|
||||
def add_attrs_init(self) -> t.Self:
|
||||
self._cls_dict["__attrs_init__"] = codegen.add_method_dunders(
|
||||
self._cls,
|
||||
_make_init(
|
||||
self._cls,
|
||||
self._attrs,
|
||||
self._has_pre_init,
|
||||
self._has_post_init,
|
||||
False, # frozen
|
||||
True, # slots
|
||||
False, # cache_hash
|
||||
self._base_attr_map,
|
||||
False, # This is not an exception
|
||||
None, # no on_setattr
|
||||
attrs_init=True,
|
||||
),
|
||||
)
|
||||
return self
|
||||
|
||||
def add_repr(self):
|
||||
for key, fn in ReprMixin.__dict__.items():
|
||||
if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"):
|
||||
self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
|
||||
self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"})
|
||||
return self
|
||||
|
||||
def __init_subclass__(cls: type[LLMConfig]):
|
||||
"""The purpose of this __init_subclass__ is that we want all subclass of LLMConfig
|
||||
to adhere to the attrs contract, and have pydantic-like interface. This means we will
|
||||
construct all fields and metadata and hack into how attrs use some of the 'magic' construction
|
||||
to generate the fields.
|
||||
"""The purpose of this ``__init_subclass__`` is to offer pydantic UX while adhering to attrs contract.
|
||||
|
||||
This means we will construct all fields and metadata and hack into
|
||||
how attrs use some of the 'magic' construction to generate the fields.
|
||||
|
||||
It also does a few more extra features: It also generate all __openllm_*__ config from
|
||||
ModelSettings (derived from __config__) to the class.
|
||||
@@ -1261,7 +1284,7 @@ class LLMConfig(_ConfigAttr):
|
||||
logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__)
|
||||
cls.__name__ = f"{cls.__name__}Config"
|
||||
|
||||
if not hasattr(cls, "__config__") or cls.__config__ is None:
|
||||
if not hasattr(cls, "__config__"):
|
||||
raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.")
|
||||
|
||||
# auto assignment attributes generated from __config__ after create the new slot class.
|
||||
@@ -1320,7 +1343,7 @@ class LLMConfig(_ConfigAttr):
|
||||
a.name for a in attr.fields(cls.__openllm_generation_class__)
|
||||
}
|
||||
|
||||
cls = cls._ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class()
|
||||
cls = _ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class()
|
||||
|
||||
# Finally, resolve the types
|
||||
if getattr(cls, "__attrs_types_resolved__", None) != cls:
|
||||
@@ -1398,11 +1421,11 @@ class LLMConfig(_ConfigAttr):
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bettertransformer"] = ...) -> bool: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["model_type"] = ...) -> t.Literal['causal_lm', 'seq2seq_lm']: ...
|
||||
def __getitem__(self, item: t.Literal["model_type"] = ...) -> t.Literal["causal_lm", "seq2seq_lm"]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["runtime"] = ...) -> t.Literal['transformers', 'ggml']: ...
|
||||
def __getitem__(self, item: t.Literal["runtime"] = ...) -> t.Literal["transformers", "ggml"]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["name_type"] = ...) -> t.Optional[t.Literal['dasherize', 'lowercase']]: ...
|
||||
def __getitem__(self, item: t.Literal["name_type"] = ...) -> t.Optional[t.Literal["dasherize", "lowercase"]]: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["model_name"] = ...) -> str: ...
|
||||
@overload
|
||||
@@ -1516,7 +1539,7 @@ class LLMConfig(_ConfigAttr):
|
||||
# fmt: on
|
||||
|
||||
def __getitem__(self, item: t.LiteralString | t.Any = None) -> t.Any:
|
||||
"""Allowing access LLMConfig as a dictionary. The order will always evaluate as
|
||||
"""Allowing access LLMConfig as a dictionary. The order will always evaluate as.
|
||||
|
||||
__openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__
|
||||
|
||||
@@ -1599,7 +1622,6 @@ class LLMConfig(_ConfigAttr):
|
||||
**attrs: The attributes to be added to the new class. This will override
|
||||
any existing attributes with the same name.
|
||||
"""
|
||||
assert cls.__config__ is not None, "Cannot derivate a LLMConfig without __config__"
|
||||
_new_cfg = {k: v for k, v in attrs.items() if k in attr.fields_dict(_ModelSettingsAttr)}
|
||||
attrs = {k: v for k, v in attrs.items() if k not in _new_cfg}
|
||||
new_cls = types.new_class(
|
||||
@@ -1642,14 +1664,12 @@ class LLMConfig(_ConfigAttr):
|
||||
try:
|
||||
attrs = orjson.loads(json_str)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}")
|
||||
raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None
|
||||
return bentoml_cattr.structure(attrs, cls)
|
||||
|
||||
@classmethod
|
||||
def model_construct_env(cls, **attrs: t.Any) -> t.Self:
|
||||
"""A helpers that respect configuration values that
|
||||
sets from environment variables for any given configuration class.
|
||||
"""
|
||||
"""A helpers that respect configuration values environment variables."""
|
||||
attrs = {k: v for k, v in attrs.items() if v is not None}
|
||||
|
||||
model_config = cls.__openllm_env__.config
|
||||
@@ -1696,7 +1716,7 @@ class LLMConfig(_ConfigAttr):
|
||||
return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove}
|
||||
|
||||
@overload
|
||||
def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig:
|
||||
def to_generation_config(self, return_as_dict: t.Literal[False] = False) -> transformers.GenerationConfig:
|
||||
...
|
||||
|
||||
@overload
|
||||
@@ -1708,22 +1728,12 @@ class LLMConfig(_ConfigAttr):
|
||||
return config.to_dict() if return_as_dict else config
|
||||
|
||||
@classmethod
|
||||
@overload
|
||||
def to_click_options(
|
||||
cls, f: t.Callable[..., openllm.LLMConfig]
|
||||
) -> F[P, ClickFunctionWrapper[..., openllm.LLMConfig]]:
|
||||
...
|
||||
def to_click_options(cls, f: AnyCallable) -> click.Command:
|
||||
"""Convert current configuration to click options.
|
||||
|
||||
@classmethod
|
||||
@overload
|
||||
def to_click_options(cls, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]:
|
||||
...
|
||||
This can be used as a decorator for click commands.
|
||||
|
||||
@classmethod
|
||||
def to_click_options(cls, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
|
||||
"""
|
||||
Convert current model to click options. This can be used as a decorator for click commands.
|
||||
Note that the identifier for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
|
||||
> **Note**: that the identifier for all LLMConfig will be prefixed with '<model_name>_*', and the generation config
|
||||
will be prefixed with '<model_name>_generation_*'.
|
||||
"""
|
||||
for name, field in attr.fields_dict(cls.__openllm_generation_class__).items():
|
||||
@@ -1769,8 +1779,7 @@ bentoml_cattr.register_unstructure_hook_factory(
|
||||
|
||||
|
||||
def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig:
|
||||
"""
|
||||
Structure a dictionary to a LLMConfig object.
|
||||
"""Structure a dictionary to a LLMConfig object.
|
||||
|
||||
Essentially, if the given dictionary contains a 'generation_config' key, then we will
|
||||
use it for LLMConfig.generation_config
|
||||
|
||||
@@ -14,14 +14,15 @@
|
||||
|
||||
"""Generation utilities to be reused throughout."""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
"""This class used to stop generation when a seq of tokens are met.
|
||||
|
||||
@@ -42,6 +43,6 @@ class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
|
||||
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool:
|
||||
stop_ids = {50278, 50279, 50277, 1, 0}
|
||||
return input_ids[0][-1] in stop_ids
|
||||
return t.cast(int, input_ids[0][-1]) in stop_ids
|
||||
|
||||
1022
src/openllm/_llm.py
1022
src/openllm/_llm.py
File diff suppressed because it is too large
Load Diff
@@ -11,17 +11,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Any build-related utilities. This is used for CI.
|
||||
"""Build-related utilities. Some of these utilities are mainly used for 'openllm.build'.
|
||||
|
||||
These utilities will stay internal, and its API can be changed or updated without backward-compatibility.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import typing as t
|
||||
from pathlib import Path
|
||||
|
||||
@@ -34,7 +31,6 @@ from simple_di import Provide
|
||||
from simple_di import inject
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.bento.build_config import BentoBuildConfig
|
||||
from bentoml._internal.bento.build_config import DockerOptions
|
||||
from bentoml._internal.bento.build_config import PythonOptions
|
||||
@@ -56,7 +52,7 @@ from .utils import resolve_user_filepath
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
from bentoml._internal.bento import BentoStore
|
||||
import openllm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -135,7 +131,8 @@ def construct_python_options(
|
||||
env: EnvVarMixin = llm.config["env"]
|
||||
framework_envvar = env["framework_value"]
|
||||
if framework_envvar == "flax":
|
||||
assert is_flax_available(), f"Flax is not available, while {env.framework} is set to 'flax'"
|
||||
if not is_flax_available():
|
||||
raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
packages.extend(
|
||||
[
|
||||
handle_package_version("flax", has_dockerfile_template),
|
||||
@@ -144,7 +141,8 @@ def construct_python_options(
|
||||
]
|
||||
)
|
||||
elif framework_envvar == "tf":
|
||||
assert is_tf_available(), f"TensorFlow is not available, while {env.framework} is set to 'tf'"
|
||||
if not is_tf_available():
|
||||
raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
candidates = (
|
||||
"tensorflow",
|
||||
"tensorflow-cpu",
|
||||
@@ -170,7 +168,8 @@ def construct_python_options(
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass
|
||||
else:
|
||||
assert is_torch_available(), "PyTorch is not available. Make sure to have it locally installed."
|
||||
if not is_torch_available():
|
||||
raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
|
||||
packages.extend([handle_package_version("torch", has_dockerfile_template)])
|
||||
|
||||
wheels: list[str] = []
|
||||
@@ -206,7 +205,7 @@ def construct_docker_options(
|
||||
"OPENLLM_MODEL": llm.config["model_name"],
|
||||
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'",
|
||||
"BENTOML_DEBUG": str(get_debug_mode()),
|
||||
"BENTOML_CONFIG_OPTIONS": _bentoml_config_options,
|
||||
"BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'",
|
||||
}
|
||||
|
||||
if adapter_map:
|
||||
@@ -257,7 +256,8 @@ def create_bento(
|
||||
logger.info("Building Bento for '%s'", llm.config["start_name"])
|
||||
|
||||
if adapter_map is not None:
|
||||
assert build_ctx is not None, "build_ctx is required when 'adapter_map' is not None"
|
||||
if build_ctx is None:
|
||||
raise ValueError("build_ctx is required when 'adapter_map' is not None")
|
||||
updated_mapping: dict[str, str | None] = {}
|
||||
for adapter_id, name in adapter_map.items():
|
||||
try:
|
||||
@@ -321,7 +321,7 @@ def create_bento(
|
||||
# new behaviour with BentoML models
|
||||
model = _model_store.get(f"{model_framework}-{model_type}")
|
||||
except bentoml.exceptions.NotFound:
|
||||
raise OpenLLMException(f"Failed to find models for {llm.config['start_name']}")
|
||||
raise OpenLLMException(f"Failed to find models for {llm.config['start_name']}") from None
|
||||
|
||||
# NOTE: the model_id_path here are only used for setting this environment variable within the container
|
||||
# built with for BentoLLM.
|
||||
@@ -330,10 +330,12 @@ def create_bento(
|
||||
with open(service_path, "r") as f:
|
||||
service_contents = f.readlines()
|
||||
|
||||
rel_path = f"../models/{model.tag.path()}"
|
||||
|
||||
for it in service_contents:
|
||||
if codegen.OPENLLM_MODEL_ID in it:
|
||||
service_contents[service_contents.index(it)] = (
|
||||
codegen.ModelIdFormatter(str(model.tag)).vformat(it)[: -(len(codegen.OPENLLM_MODEL_ID) + 3)] + "\n"
|
||||
codegen.ModelIdFormatter(rel_path).vformat(it)[: -(len(codegen.OPENLLM_MODEL_ID) + 3)] + "\n"
|
||||
)
|
||||
if "__bento_name__" in it:
|
||||
service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
@@ -346,70 +348,3 @@ def create_bento(
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
|
||||
return bento.save()
|
||||
|
||||
|
||||
@inject
|
||||
def build(
|
||||
model_name: str,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
build_ctx: str | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
workers_per_resource: int | float | None = None,
|
||||
overwrite_existing_bento: bool = False,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
dockerfile_template: str | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
) -> bentoml.Bento:
|
||||
"""Package a LLM into a Bento.
|
||||
|
||||
The LLM will be built into a BentoService with the following structure:
|
||||
if quantize is passed, it will instruct the model to be quantized dynamically during serving time.
|
||||
if bettertransformer is passed, it will instruct the model to use BetterTransformer during serving time.
|
||||
|
||||
Other parameters including model_name, model_id and attrs will be passed to the LLM class itself.
|
||||
"""
|
||||
args = [sys.executable, "-m", "openllm", "build", model_name, "--machine", "--runtime", runtime]
|
||||
|
||||
if quantize and bettertransformer:
|
||||
raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")
|
||||
|
||||
if quantize:
|
||||
args.extend(["--quantize", quantize])
|
||||
if bettertransformer:
|
||||
args.append("--bettertransformer")
|
||||
|
||||
if model_id:
|
||||
args.extend(["--model-id", model_id])
|
||||
if build_ctx:
|
||||
args.extend(["--build-ctx", build_ctx])
|
||||
if extra_dependencies:
|
||||
args.extend([f"--enable-features={f}" for f in extra_dependencies])
|
||||
if workers_per_resource:
|
||||
args.extend(["--workers-per-resource", str(workers_per_resource)])
|
||||
if overwrite_existing_bento:
|
||||
args.append("--overwrite")
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
args.extend(["--model-version", model_version])
|
||||
if dockerfile_template:
|
||||
args.extend(["--dockerfile-template", dockerfile_template])
|
||||
|
||||
try:
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Exception caught while building %s", model_name, exc_info=e)
|
||||
if e.stderr:
|
||||
raise OpenLLMException(e.stderr.decode("utf-8")) from None
|
||||
raise OpenLLMException(str(e)) from None
|
||||
# NOTE: This usually only concern BentoML devs.
|
||||
pattern = r"^__tag__:[^:\n]+:[^:\n]+"
|
||||
matched = re.search(pattern, output.decode("utf-8").strip(), re.MULTILINE)
|
||||
assert matched is not None, f"Failed to find tag from output: {output}"
|
||||
_, _, tag = matched.group(0).partition(":")
|
||||
return bentoml.get(tag, _bento_store=bento_store)
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
@@ -20,10 +19,10 @@ import typing as t
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str:
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.LiteralString:
|
||||
if len(args) > 0:
|
||||
raise ValueError("Positional arguments are not supported")
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
return t.cast("t.LiteralString", super().vformat(format_string, args, kwargs))
|
||||
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
|
||||
@@ -11,3 +11,92 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from .utils import LazyLoader
|
||||
from .utils import is_bitsandbytes_available
|
||||
from .utils import is_transformers_supports_kbit
|
||||
from .utils import pkg
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
import openllm
|
||||
import transformers
|
||||
|
||||
from ._types import DictStrAny
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QuantiseMode = t.Literal["int8", "int4", "gptq"]
|
||||
|
||||
|
||||
def infer_quantisation_config(
|
||||
cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any
|
||||
) -> tuple[transformers.BitsAndBytesConfig | t.Any, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
|
||||
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
|
||||
|
||||
def create_int8_config(int8_skip_modules: list[str] | None):
|
||||
if int8_skip_modules is None:
|
||||
int8_skip_modules = []
|
||||
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
|
||||
int8_skip_modules.append("lm_head")
|
||||
return transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16)
|
||||
int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4")
|
||||
int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True)
|
||||
|
||||
# NOTE: Quantization setup
|
||||
# quantize is a openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError(
|
||||
"Quantization requires bitsandbytes to be installed. Make "
|
||||
"sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
if quantise == "int8":
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "int4":
|
||||
if is_transformers_supports_kbit():
|
||||
quantisation_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
bnb_4bit_quant_type=int4_quant_type,
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"'quantize' is set to int4, while the current transformers version %s does not support "
|
||||
"k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore "
|
||||
"make sure to install the latest version of transformers either via PyPI or "
|
||||
"from git source: 'pip install git+https://github.com/huggingface/transformers'.",
|
||||
pkg.pkg_version_info("transformers"),
|
||||
)
|
||||
logger.warning("OpenLLM will fallback to 8-bit quantization.")
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "gptq":
|
||||
# TODO: support GPTQ loading quantization
|
||||
raise NotImplementedError("GPTQ is not supported yet.")
|
||||
else:
|
||||
raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.")
|
||||
|
||||
return quantisation_config, attrs
|
||||
|
||||
@@ -11,11 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Schema definition for OpenLLM. This can be use for client interaction.
|
||||
"""
|
||||
"""Schema definition for OpenLLM. This can be use for client interaction."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import typing as t
|
||||
|
||||
@@ -23,6 +20,8 @@ import attr
|
||||
import inflection
|
||||
|
||||
import openllm
|
||||
from openllm._configuration import GenerationConfig
|
||||
from openllm.utils import bentoml_cattr
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -79,6 +78,14 @@ class GenerationOutput:
|
||||
configuration: t.Dict[str, t.Any]
|
||||
"""A mapping of configuration values for given system."""
|
||||
|
||||
@property
|
||||
def marshaled_config(self) -> GenerationConfig:
|
||||
return bentoml_cattr.structure(self.configuration, GenerationConfig)
|
||||
|
||||
@property
|
||||
def unmarshaled(self) -> dict[str, t.Any]:
|
||||
return bentoml_cattr.unstructure(self)
|
||||
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class MetadataOutput:
|
||||
|
||||
@@ -12,8 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
The service definition for running any LLMService.
|
||||
"""The service definition for running any LLMService.
|
||||
|
||||
Note that the line `model = ...` is a special line and should not be modified. This will be handled by openllm
|
||||
internally to generate the correct model service when bundling the LLM to a Bento.
|
||||
@@ -22,7 +21,6 @@ This will ensure that 'bentoml serve llm-bento' will work accordingly.
|
||||
The generation code lives under utils/codegen.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing as t
|
||||
import warnings
|
||||
@@ -136,7 +134,7 @@ async def hf_agent(request: Request) -> Response:
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
|
||||
|
||||
stop = input_data.parameters.pop("stop", "\n")
|
||||
stop = input_data.parameters.pop("stop", ["\n"])
|
||||
try:
|
||||
resp = await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters)
|
||||
return JSONResponse(resp, status_code=200)
|
||||
@@ -150,9 +148,11 @@ svc.mount_asgi_app(hf_app, path="/hf")
|
||||
|
||||
|
||||
async def list_adapter_v1(_: Request) -> Response:
|
||||
res = runner.peft_adapters
|
||||
if res["success"]:
|
||||
res["result"] = {k: v.to_dict() for k, v in res["result"].items()}
|
||||
res: dict[str, t.Any] = {}
|
||||
if runner.peft_adapters["success"] is True:
|
||||
res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()}
|
||||
res["success"] = runner.peft_adapters["success"]
|
||||
res["error_msg"] = runner.peft_adapters["error_msg"]
|
||||
return JSONResponse(res, status_code=200)
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -23,19 +21,19 @@ import typing as t
|
||||
|
||||
import psutil
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
from bentoml._internal.resource import Resource
|
||||
from bentoml._internal.resource import get_resource
|
||||
from bentoml._internal.resource import system_resources
|
||||
from bentoml._internal.runner.strategy import THREAD_ENVS
|
||||
from bentoml._internal.runner.strategy import Strategy
|
||||
|
||||
from .utils import LazyType
|
||||
from .exceptions import OpenLLMException
|
||||
from .utils import ReprMixin
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
|
||||
ListIntStr = list[int | str]
|
||||
else:
|
||||
ListIntStr = list
|
||||
@@ -43,28 +41,46 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"):
|
||||
class AmdGpuResource(Resource[t.List[str]], resource_id="amd.com/gpu"):
|
||||
@classmethod
|
||||
def from_spec(cls, spec: int | str | list[str | int]) -> list[int]:
|
||||
if not isinstance(spec, (int, str)) and not LazyType(ListIntStr).isinstance(spec):
|
||||
def from_spec(cls, spec: t.Any) -> list[str]:
|
||||
if not isinstance(spec, (int, str, list)):
|
||||
raise TypeError("AMD GPU device IDs must be int, str or a list specifing the exact GPUs to use.")
|
||||
|
||||
try:
|
||||
if isinstance(spec, int):
|
||||
if spec == -1:
|
||||
return []
|
||||
if spec < -1:
|
||||
raise ValueError
|
||||
return list(range(spec))
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
return cls.from_spec(int(spec))
|
||||
try:
|
||||
return cls.from_spec(int(spec))
|
||||
except ValueError:
|
||||
if spec.startswith("GPU"):
|
||||
return [spec]
|
||||
raise ValueError
|
||||
else:
|
||||
return [int(x) for x in spec]
|
||||
return [str(x) for x in spec]
|
||||
except ValueError:
|
||||
raise openllm.exceptions.OpenLLMException(f"Invalid AMD GPU resource limit '{spec}'. ")
|
||||
raise OpenLLMException(f"Invalid AMD GPU resource limit '{spec}'.")
|
||||
|
||||
@classmethod # type: ignore (overload)
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def from_system(cls) -> list[int]:
|
||||
@classmethod
|
||||
def from_system(cls) -> list[str]:
|
||||
"""Retrieve AMD GPU from system, currently only supports on Linux.
|
||||
This assumes that ROCm is setup correctly."""
|
||||
|
||||
This assumes that ROCm is setup correctly.
|
||||
"""
|
||||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
|
||||
if cuda_visible_devices in ("", "-1"):
|
||||
return []
|
||||
if cuda_visible_devices is not None:
|
||||
cuda_visible_devices = cuda_visible_devices.split(",")
|
||||
if "-1" in cuda_visible_devices:
|
||||
cuda_visible_devices = cuda_visible_devices[: cuda_visible_devices.index("-1")]
|
||||
return cuda_visible_devices
|
||||
|
||||
if not psutil.LINUX:
|
||||
logger.debug("AMD GPU resource is only supported on Linux.")
|
||||
return []
|
||||
@@ -84,7 +100,7 @@ class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"):
|
||||
num = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(num))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS:
|
||||
return list(range(num.value))
|
||||
return [str(i) for i in range(num.value)]
|
||||
return []
|
||||
except Exception as err:
|
||||
logger.debug("Failed to setup AMD GPU resource: %s", err)
|
||||
@@ -93,18 +109,22 @@ class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"):
|
||||
sys.path.remove("/opt/rocm/libexec/rocm_smi")
|
||||
|
||||
@classmethod
|
||||
def validate(cls, val: list[int]):
|
||||
if any(gpu_index < 0 for gpu_index in val):
|
||||
raise openllm.exceptions.OpenLLMException(f"Negative GPU device in {val}.")
|
||||
if any(gpu_index >= len(cls.from_system()) for gpu_index in val):
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPU device index in {val} is greater than the system available: {cls.from_system()}"
|
||||
)
|
||||
def validate(cls, val: list[str]):
|
||||
for gpu_index_or_literal in val:
|
||||
try:
|
||||
idx = int(gpu_index_or_literal)
|
||||
except ValueError:
|
||||
raise OpenLLMException(f"Invalid AMD GPU device index: {val}")
|
||||
if int(idx) < 0:
|
||||
raise OpenLLMException(f"Negative GPU device in {val}.")
|
||||
if int(idx) >= len(cls.from_system()):
|
||||
raise OpenLLMException(
|
||||
f"GPU device index in {val} is greater than the system available: {cls.from_system()}"
|
||||
)
|
||||
|
||||
|
||||
class CascadingResourceStrategy(Strategy, ReprMixin):
|
||||
"""This is rather an extension of bentoml._internal.runner.strategy.DefaultStrategy
|
||||
where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
|
||||
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
|
||||
|
||||
It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU.
|
||||
See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices
|
||||
@@ -147,9 +167,9 @@ class CascadingResourceStrategy(Strategy, ReprMixin):
|
||||
)
|
||||
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
if isinstance(workers_per_resource, float):
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0:
|
||||
raise ValueError("Fractional CPU multi threading support is not yet supported.")
|
||||
return workers_per_resource
|
||||
return int(workers_per_resource)
|
||||
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
@@ -167,11 +187,13 @@ class CascadingResourceStrategy(Strategy, ReprMixin):
|
||||
workers_per_resource: int | float,
|
||||
worker_index: int,
|
||||
) -> dict[str, t.Any]:
|
||||
"""
|
||||
"""Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
runnable_class : The runnable class to be run.
|
||||
resource_request : The resource request of the runnable.
|
||||
worker_index : The index of the worker, start from 0.
|
||||
runnable_class: The runnable class to be run.
|
||||
resource_request: The resource request of the runnable.
|
||||
workers_per_resource: # of workers per resource.
|
||||
worker_index: The index of the worker, start from 0.
|
||||
"""
|
||||
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
disabled = cuda_env in ("", "-1")
|
||||
|
||||
@@ -11,35 +11,43 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Types definition for OpenLLM.
|
||||
"""Types definition for OpenLLM.
|
||||
|
||||
Note that this module SHOULD NOT BE IMPORTED DURING RUNTIME, as this serve only for typing purposes.
|
||||
It will raises a RuntimeError if this is imported eagerly.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
|
||||
if not t.TYPE_CHECKING:
|
||||
raise RuntimeError(f"{__name__} should not be imported during runtime")
|
||||
|
||||
import click
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import transformers
|
||||
|
||||
from ._configuration import AdapterType
|
||||
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import click
|
||||
import peft
|
||||
|
||||
import openllm
|
||||
import transformers
|
||||
from bentoml._internal.runner.runnable import RunnableMethod
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = dict[str, t.Any]
|
||||
ListAny = list[t.Any]
|
||||
ListStr = list[str]
|
||||
TupleAny = tuple[t.Any, ...]
|
||||
P = t.ParamSpec("P")
|
||||
O_co = t.TypeVar("O_co", covariant=True)
|
||||
LiteralRuntime: t.TypeAlias = t.Literal["pt", "tf", "flax"]
|
||||
T = t.TypeVar("T")
|
||||
Ts = t.TypeVarTuple("Ts")
|
||||
|
||||
|
||||
class ClickFunctionWrapper(t.Protocol[P, O_co]):
|
||||
@@ -83,9 +91,19 @@ class TokenizerProtocol(_StubsMixin[_MT], t.Protocol):
|
||||
...
|
||||
|
||||
|
||||
PeftAdapterOutput = dict[t.Literal["success", "result", "error_msg"], bool | str | dict[t.Any, t.Any]]
|
||||
class PeftAdapterOutput(t.TypedDict):
|
||||
success: bool
|
||||
result: dict[str, peft.PeftConfig]
|
||||
error_msg: str
|
||||
|
||||
AdaptersMapping = dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None
|
||||
|
||||
class AdaptersTuple(TupleAny):
|
||||
adapter_id: str
|
||||
name: str | None
|
||||
config: DictStrAny
|
||||
|
||||
|
||||
AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]] | None
|
||||
|
||||
|
||||
class LLMRunnable(bentoml.Runnable):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,8 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
OpenLLM client.
|
||||
"""OpenLLM client.
|
||||
|
||||
To start interact with the server, you can do the following:
|
||||
|
||||
@@ -21,7 +20,6 @@ To start interact with the server, you can do the following:
|
||||
>>> client.query("What is the meaning of life?")
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import itertools
|
||||
import typing as t
|
||||
|
||||
@@ -11,9 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Base exceptions for OpenLLM. This extends BentoML exceptions.
|
||||
"""
|
||||
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
|
||||
from __future__ import annotations
|
||||
|
||||
import bentoml
|
||||
|
||||
@@ -15,12 +15,14 @@
|
||||
"""This module is derived from HuggingFace's AutoConfig, AutoModel, etc."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from ...utils import is_torch_available, is_flax_available, is_tf_available, LazyModule
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_flax_available
|
||||
from ...utils import is_tf_available
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
@@ -24,6 +22,7 @@ import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
from collections import _odict_items
|
||||
from collections import _odict_keys
|
||||
from collections import _odict_values
|
||||
@@ -93,9 +92,7 @@ class _LazyConfigMapping(ConfigOrderedDict):
|
||||
return item in self._mapping or item in self._extra_content
|
||||
|
||||
def register(self, key: str, value: t.Any):
|
||||
"""
|
||||
Register a new configuration in this mapping.
|
||||
"""
|
||||
"""Register a new configuration in this mapping."""
|
||||
if key in self._mapping.keys():
|
||||
raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
|
||||
self._extra_content[key] = value
|
||||
@@ -115,7 +112,10 @@ CONFIG_NAME_ALIASES: dict[str, str] = {
|
||||
|
||||
class AutoConfig:
|
||||
def __init__(self, *_: t.Any, **__: t.Any):
|
||||
raise EnvironmentError("Cannot instantiate Config. Please use `Config.for_model(model_name)` instead.")
|
||||
"""This metaclass should be initialised via `AutoConfig.for_model`."""
|
||||
raise EnvironmentError(
|
||||
"Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig:
|
||||
|
||||
@@ -13,11 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
import sys
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
@@ -30,13 +29,14 @@ from .configuration_auto import AutoConfig
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
from collections import _odict_items
|
||||
from collections import _odict_keys
|
||||
from collections import _odict_values
|
||||
@@ -54,10 +54,10 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _BaseAutoLLMClass:
|
||||
class BaseAutoLLMClass:
|
||||
_model_mapping: _LazyAutoMapping
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any): # noqa
|
||||
raise EnvironmentError(
|
||||
f"Cannot instantiate {self.__class__.__name__} directly. "
|
||||
"Please use '{self.__class__.__name__}.Runner(model_name)' instead."
|
||||
@@ -129,7 +129,7 @@ class _BaseAutoLLMClass:
|
||||
llm = model_class.from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs)
|
||||
if ensure_available:
|
||||
logger.debug(
|
||||
"'ensure_available=True', Downloading '%s' with 'model_id=%s' to local model store.",
|
||||
"'ensure_available=True', OpenLLM will automatically import '%s' with 'model_id=%s' to local store if the entry does not exists.",
|
||||
model,
|
||||
llm.model_id,
|
||||
)
|
||||
@@ -144,8 +144,7 @@ class _BaseAutoLLMClass:
|
||||
|
||||
@classmethod
|
||||
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner:
|
||||
"""
|
||||
Create a LLM Runner for the given model name.
|
||||
"""Create a LLM Runner for the given model name.
|
||||
|
||||
Args:
|
||||
model: The model name to instantiate.
|
||||
@@ -160,8 +159,7 @@ class _BaseAutoLLMClass:
|
||||
|
||||
@classmethod
|
||||
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]):
|
||||
"""
|
||||
Register a new model for this class.
|
||||
"""Register a new model for this class.
|
||||
|
||||
Args:
|
||||
config_class: The configuration corresponding to the model to register.
|
||||
@@ -191,13 +189,14 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
|
||||
try:
|
||||
return getattribute_from_module(openllm_module, attr)
|
||||
except ValueError:
|
||||
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!")
|
||||
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
|
||||
else:
|
||||
raise ValueError(f"Could not find {attr} in {openllm_module}!")
|
||||
|
||||
|
||||
class _LazyAutoMapping(ConfigModelOrderedDict):
|
||||
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping
|
||||
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.
|
||||
|
||||
This OrderedDict values() and keys() returns the list instead, so you don't
|
||||
have to do list(mapping.values()) to get the list of values.
|
||||
"""
|
||||
@@ -281,9 +280,7 @@ class _LazyAutoMapping(ConfigModelOrderedDict):
|
||||
return model_type in self._model_mapping
|
||||
|
||||
def register(self, key: t.Any, value: t.Any):
|
||||
"""
|
||||
Register a new model in this mapping.
|
||||
"""
|
||||
"""Register a new model in this mapping."""
|
||||
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
|
||||
model_type = self._reverse_config_mapping[key.__name__]
|
||||
if model_type in self._model_mapping.keys():
|
||||
@@ -292,4 +289,4 @@ class _LazyAutoMapping(ConfigModelOrderedDict):
|
||||
self._extra_content[key] = value
|
||||
|
||||
|
||||
__all__ = ["_BaseAutoLLMClass", "_LazyAutoMapping"]
|
||||
__all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"]
|
||||
|
||||
@@ -13,12 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import _BaseAutoLLMClass
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
|
||||
@@ -45,5 +44,5 @@ MODEL_MAPPING: dict[
|
||||
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoLLM(_BaseAutoLLMClass):
|
||||
class AutoLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_MAPPING
|
||||
|
||||
@@ -13,12 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import _BaseAutoLLMClass
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
|
||||
@@ -38,5 +37,5 @@ MODEL_FLAX_MAPPING: dict[
|
||||
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoFlaxLLM(_BaseAutoLLMClass):
|
||||
class AutoFlaxLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_FLAX_MAPPING
|
||||
|
||||
@@ -13,12 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import _BaseAutoLLMClass
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
|
||||
@@ -38,5 +37,5 @@ MODEL_TF_MAPPING: dict[
|
||||
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoTFLLM(_BaseAutoLLMClass):
|
||||
class AutoTFLLM(BaseAutoLLMClass):
|
||||
_model_mapping = MODEL_TF_MAPPING
|
||||
|
||||
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, is_cpm_kernels_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_cpm_kernels_available
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -17,9 +17,7 @@ import openllm
|
||||
|
||||
|
||||
class ChatGLMConfig(openllm.LLMConfig):
|
||||
"""
|
||||
ChatGLM is an open bilingual language model based on
|
||||
[General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
|
||||
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
|
||||
|
||||
With the quantization technique, users can deploy locally on consumer-grade graphics cards
|
||||
(only 6GB of GPU memory is required at the INT4 quantization level).
|
||||
|
||||
@@ -12,15 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ...utils import generate_labels
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
else:
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
@@ -34,15 +36,15 @@ class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrain
|
||||
self.device = torch.device("cuda")
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
|
||||
(_, model_attrs), tokenizer_kwds = self.llm_parameters
|
||||
attrs = {**model_attrs, **attrs}
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag,
|
||||
transformers.AutoModel.from_pretrained(self.model_id, trust_remote_code=trust_remote_code),
|
||||
labels=generate_labels(self),
|
||||
custom_objects={
|
||||
"tokenizer": transformers.AutoTokenizer.from_pretrained(
|
||||
self.model_id, trust_remote_code=trust_remote_code, **tokenizer_kwds
|
||||
self.model_id, trust_remote_code=trust_remote_code, **tokenizer_attrs
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -62,8 +64,8 @@ class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrain
|
||||
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history):
|
||||
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n"
|
||||
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:"
|
||||
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001
|
||||
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_dolly_v2": ["DollyV2Config", "START_DOLLY_V2_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"],
|
||||
|
||||
@@ -11,12 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
The following includes OpenLLM configuration and excerpt from
|
||||
[instruct_pipeline.py](https://huggingface.co/databricks/dolly-v2-3b/blob/main/instruct_pipeline.py)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
@@ -27,8 +22,7 @@ if t.TYPE_CHECKING:
|
||||
|
||||
|
||||
class DollyV2Config(openllm.LLMConfig):
|
||||
"""Databricks’ Dolly is an instruction-following large language model trained on the Databricks
|
||||
machine learning platform that is licensed for commercial use.
|
||||
"""Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
|
||||
|
||||
Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k
|
||||
generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming,
|
||||
@@ -103,15 +97,19 @@ DEFAULT_PROMPT_TEMPLATE = """{intro}
|
||||
|
||||
def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
|
||||
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
|
||||
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
||||
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizer): the tokenizer
|
||||
key (str): the key to convert to a single token
|
||||
tokenizer: the tokenizer
|
||||
key: the key to convert to a single token
|
||||
|
||||
Raises:
|
||||
RuntimeError: if more than one ID was generated
|
||||
|
||||
Returns:
|
||||
int: the token ID for the given key
|
||||
int: the token ID for the given key.
|
||||
"""
|
||||
token_ids = tokenizer.encode(key)
|
||||
if len(token_ids) > 1:
|
||||
|
||||
@@ -12,24 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ...utils import normalize_attrs_to_model_tokenizer_pair
|
||||
from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_dolly_v2 import END_KEY
|
||||
from .configuration_dolly_v2 import RESPONSE_KEY
|
||||
from .configuration_dolly_v2 import get_special_token_id
|
||||
from ...utils import normalize_attrs_to_model_tokenizer_pair
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
import bentoml
|
||||
import transformers
|
||||
else:
|
||||
tf = openllm.utils.LazyLoader("tf", globals(), "tensorflow")
|
||||
@@ -75,14 +75,16 @@ def get_pipeline(
|
||||
top_k: int = 0,
|
||||
**kwargs: t.Any,
|
||||
):
|
||||
"""Initialize the pipeline
|
||||
"""Initialize the pipeline.
|
||||
|
||||
Args:
|
||||
do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
|
||||
max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 128.
|
||||
top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with
|
||||
probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
|
||||
top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
Defaults to 0.
|
||||
do_sample: Whether or not to use sampling. Defaults to True.
|
||||
max_new_tokens: Max new tokens after the prompt to generate. Defaults to 128.
|
||||
top_p: If set to float < 1, only the smallest set of most probable tokens with
|
||||
probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
|
||||
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to 0.
|
||||
*args: Additional positional arguments to be passed to ``transformers.Pipeline``.
|
||||
**kwargs: Additional keyword arguments to be passed to ``transformers.Pipeline``.
|
||||
"""
|
||||
super().__init__(
|
||||
*args,
|
||||
@@ -195,7 +197,7 @@ def get_pipeline(
|
||||
try:
|
||||
response_pos = sequence.index(response_key_token_id)
|
||||
except ValueError:
|
||||
logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}")
|
||||
logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
|
||||
response_pos = None
|
||||
|
||||
if response_pos:
|
||||
@@ -228,7 +230,7 @@ def get_pipeline(
|
||||
if m:
|
||||
decoded = m.group(1).strip()
|
||||
else:
|
||||
logger.warn(f"Failed to find response in:\n{fully_decoded}")
|
||||
logger.warning("Failed to find response in:\n%s", fully_decoded)
|
||||
|
||||
# If the full text is requested, then append the decoded text to the original instruction.
|
||||
# This technically isn't the full text, as we format the instruction in the prompt the model has been
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -17,9 +17,9 @@ import openllm
|
||||
|
||||
|
||||
class FalconConfig(openllm.LLMConfig):
|
||||
"""Falcon-7B is a 7B parameters causal decoder-only model built by
|
||||
TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
enhanced with curated corpora. It is made available under the TII Falcon LLM License.
|
||||
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
|
||||
|
||||
It is made available under the TII Falcon LLM License.
|
||||
|
||||
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
|
||||
"""
|
||||
|
||||
@@ -13,19 +13,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
import torch.amp
|
||||
|
||||
import bentoml
|
||||
import transformers
|
||||
else:
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
@@ -81,7 +81,7 @@ class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTraine
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, is_tf_available, is_flax_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_flax_available
|
||||
from ...utils import is_tf_available
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -46,8 +46,9 @@ DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruct
|
||||
|
||||
|
||||
class FlanT5Config(openllm.LLMConfig):
|
||||
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf)
|
||||
- it is an enhanced version of T5 that has been finetuned in a mixture of tasks.
|
||||
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
|
||||
|
||||
It is an enhanced version of T5 that has been finetuned in a mixture of tasks.
|
||||
|
||||
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
|
||||
"""
|
||||
|
||||
@@ -12,19 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
import transformers # noqa
|
||||
import transformers # noqa: F401
|
||||
else:
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
|
||||
@@ -60,7 +59,7 @@ class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformer
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -12,17 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers # noqa
|
||||
import transformers # noqa: F401
|
||||
|
||||
|
||||
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
@@ -54,7 +53,7 @@ class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "tra
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -12,17 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers # noqa
|
||||
import transformers # noqa: F401
|
||||
|
||||
|
||||
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
|
||||
@@ -40,17 +39,20 @@ class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transfo
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if use_default_prompt_template:
|
||||
prompt_variables = {
|
||||
k: v
|
||||
for k, v in attrs.items()
|
||||
if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
|
||||
}
|
||||
template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if "instruction" in prompt_variables:
|
||||
raise RuntimeError(
|
||||
"'instruction' should be passed as the first argument "
|
||||
"instead of kwargs when 'use_default_prompt_template=True'"
|
||||
)
|
||||
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
|
||||
try:
|
||||
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -14,14 +14,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
class GPTNeoXConfig(openllm.LLMConfig):
|
||||
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and
|
||||
openly available to the public through a permissive license. It is, to the best of our knowledge, the largest dense autoregressive model
|
||||
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
|
||||
|
||||
It is, to the best of our knowledge, the largest dense autoregressive model
|
||||
that has publicly available weights at the time of submission. The training and evaluation code, as well as the model weights,
|
||||
can be found at https://github.com/EleutherAI/gpt-neox.
|
||||
|
||||
|
||||
@@ -13,21 +13,21 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
import transformers # noqa
|
||||
import torch
|
||||
import torch.amp
|
||||
|
||||
import bentoml
|
||||
import transformers
|
||||
else:
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
@@ -62,7 +62,7 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
@@ -27,10 +26,11 @@ else:
|
||||
|
||||
|
||||
class MPTConfig(openllm.LLMConfig):
|
||||
"""MPT is a decoder-style transformer pretrained from scratch on
|
||||
English text and code. This model was trained by [MosaicML](https://www.mosaicml.com/).
|
||||
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
|
||||
|
||||
`openllm.MPT` encapsulate a family of MPT variants that is publicly available
|
||||
This model was trained by [MosaicML](https://www.mosaicml.com/).
|
||||
|
||||
``openllm.MPT`` encapsulate a family of MPT variants that is publicly available
|
||||
on HuggingFace. Refers [HuggingFace's MosaicML page](https://huggingface.co/mosaicml)
|
||||
for more details on specific models.
|
||||
"""
|
||||
|
||||
@@ -13,16 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from ...utils import is_triton_available
|
||||
from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
from ...utils import generate_labels
|
||||
from ...utils import is_triton_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -78,8 +78,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
|
||||
return model_kwds, tokenizer_kwds
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
|
||||
(_, model_attrs), tokenizer_kwds = self.llm_parameters
|
||||
attrs = {**model_attrs, **attrs}
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
device_map = attrs.pop("device_map", None)
|
||||
@@ -93,7 +92,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
if tokenizer.pad_token_id is None:
|
||||
logger.warning("pad_token_id is not set. Setting it to eos_token")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
@@ -107,7 +106,12 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
|
||||
**attrs,
|
||||
)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag,
|
||||
model,
|
||||
custom_objects={"tokenizer": tokenizer},
|
||||
labels=generate_labels(self),
|
||||
)
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -169,7 +173,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, LazyModule, is_flax_available, is_tf_available
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_flax_available
|
||||
from ...utils import is_tf_available
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -18,9 +18,7 @@ import openllm
|
||||
|
||||
|
||||
class OPTConfig(openllm.LLMConfig):
|
||||
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068)
|
||||
and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq)
|
||||
on May 3rd 2022 by Meta AI.
|
||||
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
|
||||
|
||||
OPT was predominantly pretrained with English text, but a small amount of non-English data is still present
|
||||
within the training corpus via CommonCrawl. The model was pretrained using a causal language modeling (CLM)
|
||||
|
||||
@@ -13,15 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
from ...utils import generate_labels
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -44,17 +44,21 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok
|
||||
return {}, tokenizer_kwds
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
(_, model_attrs), tokenizer_kwds = self.llm_parameters
|
||||
attrs = {**model_attrs, **attrs}
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(self.model_id)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
model = t.cast(
|
||||
"transformers.FlaxOPTForCausalLM",
|
||||
transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs),
|
||||
)
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag,
|
||||
model,
|
||||
custom_objects={"tokenizer": tokenizer},
|
||||
labels=generate_labels(self),
|
||||
)
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
@@ -81,7 +85,7 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,21 +13,21 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
from ...utils import generate_labels
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
import transformers # noqa
|
||||
import transformers
|
||||
else:
|
||||
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
|
||||
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
@@ -55,13 +55,12 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
|
||||
return model_kwds, tokenizer_kwds
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
(_, model_attrs), tokenizer_kwds = self.llm_parameters
|
||||
attrs = {**model_attrs, **attrs}
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(self.model_id)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
model = t.cast(
|
||||
"transformers.OPTForCausalLM",
|
||||
@@ -69,7 +68,12 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
|
||||
self.model_id, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, **attrs
|
||||
),
|
||||
)
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag,
|
||||
model,
|
||||
custom_objects={"tokenizer": tokenizer},
|
||||
labels=generate_labels(self),
|
||||
)
|
||||
|
||||
def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM:
|
||||
torch_dtype = attrs.pop("torch_dtype", self.dtype)
|
||||
@@ -105,7 +109,7 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
from ..._prompt import default_formatter
|
||||
from ...utils import generate_labels
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -45,16 +45,20 @@ class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Token
|
||||
return {}, tokenizer_kwds
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
(_, model_attrs), tokenizer_kwds = self.llm_parameters
|
||||
attrs = {**model_attrs, **attrs}
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
|
||||
config = transformers.AutoConfig.from_pretrained(self.model_id)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
model: transformers.TFOPTForCausalLM = transformers.TFOPTForCausalLM.from_pretrained(
|
||||
self.model_id, trust_remote_code=trust_remote_code, **attrs
|
||||
)
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag,
|
||||
model,
|
||||
custom_objects={"tokenizer": tokenizer},
|
||||
labels=generate_labels(self),
|
||||
)
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
@@ -80,7 +84,7 @@ class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Token
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. "
|
||||
"Use 'use_default_prompt_template=False' to disable the default prompt template."
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
prompt_text = prompt
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -17,8 +17,9 @@ import openllm
|
||||
|
||||
|
||||
class StableLMConfig(openllm.LLMConfig):
|
||||
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models
|
||||
pre-trained on a diverse collection of English datasets with a sequence
|
||||
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
|
||||
|
||||
It is pre-trained on a diverse collection of English datasets with a sequence
|
||||
length of 4096 to push beyond the context window limitations of existing open-source language models.
|
||||
|
||||
StableLM-Tuned-Alpha is a suite of 3B and 7B parameter decoder-only language models
|
||||
@@ -74,6 +75,6 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
""" # noqa
|
||||
"""
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"""
|
||||
|
||||
@@ -12,15 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from ..._prompt import default_formatter
|
||||
from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE
|
||||
from .configuration_stablelm import SYSTEM_PROMPT
|
||||
from ..._prompt import default_formatter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ...utils import is_torch_available, LazyModule
|
||||
from ...exceptions import MissingDependencyError
|
||||
from ...utils import LazyModule
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
|
||||
@@ -17,8 +17,7 @@ import openllm
|
||||
|
||||
|
||||
class StarCoderConfig(openllm.LLMConfig):
|
||||
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from
|
||||
[The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
|
||||
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
|
||||
|
||||
The model uses [Multi Query Attention](https://arxiv.org/abs/1911.02150),
|
||||
[a context window of 8192 tokens](https://arxiv.org/abs/2205.14135), and was trained using the
|
||||
|
||||
@@ -12,13 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from ...utils import generate_labels
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
@@ -54,13 +55,12 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
|
||||
return model_kwds, tokenizer_kwds
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
(_, model_attrs), tokenizer_kwds = self.llm_parameters
|
||||
attrs = {**model_attrs, **attrs}
|
||||
_, tokenizer_attrs = self.llm_parameters
|
||||
|
||||
torch_dtype = attrs.pop("torch_dtype", torch.float16)
|
||||
device_map = attrs.pop("device_map", "auto")
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
|
||||
@@ -72,7 +72,12 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
|
||||
self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs
|
||||
)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer})
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag,
|
||||
model,
|
||||
custom_objects={"tokenizer": tokenizer},
|
||||
labels=generate_labels(self),
|
||||
)
|
||||
finally:
|
||||
# NOTE: We need to free the cache after saving here so that we can load it back later on.
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
|
||||
@@ -37,15 +37,20 @@ llm.save_pretrained("./path/to/local-dolly")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
from ..utils import LazyModule
|
||||
import typing as t
|
||||
import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
from .._types import ModelProtocol, TokenizerProtocol
|
||||
from .transformers import _M, _T
|
||||
|
||||
from .._llm import M
|
||||
from .._llm import T
|
||||
from .._types import ModelProtocol
|
||||
from .._types import TokenizerProtocol
|
||||
|
||||
|
||||
def import_model(
|
||||
@@ -80,7 +85,7 @@ def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]:
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]:
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.load_model(llm, *decls, **attrs)
|
||||
elif llm.runtime == "ggml":
|
||||
@@ -89,7 +94,7 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo
|
||||
raise ValueError(f"Unknown runtime: {llm.config['runtime']}")
|
||||
|
||||
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]:
|
||||
if llm.runtime == "transformers":
|
||||
return openllm.transformers.load_tokenizer(llm)
|
||||
elif llm.runtime == "ggml":
|
||||
@@ -109,11 +114,6 @@ _extras = {
|
||||
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": []}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from . import import_model as import_model
|
||||
from . import get as get
|
||||
from . import save_pretrained as save_pretrained
|
||||
from . import load_model as load_model
|
||||
from . import load_tokenizer as load_tokenizer
|
||||
from . import ggml as ggml
|
||||
from . import transformers as transformers
|
||||
else:
|
||||
|
||||
@@ -16,19 +16,25 @@
|
||||
This requires ctransformers to be installed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import openllm
|
||||
import typing as t
|
||||
import bentoml
|
||||
|
||||
import cloudpickle
|
||||
from ..exceptions import OpenLLMException
|
||||
from ..utils import LazyLoader
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
|
||||
from ..exceptions import OpenLLMException
|
||||
from ..utils import LazyLoader
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .._types import ModelProtocol, TokenizerProtocol
|
||||
from .transformers import _M, _T
|
||||
import openllm
|
||||
import transformers
|
||||
|
||||
from .._llm import M
|
||||
from .._llm import T
|
||||
from .._types import ModelProtocol
|
||||
from .._types import TokenizerProtocol
|
||||
else:
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
|
||||
@@ -44,6 +50,7 @@ def import_model(
|
||||
|
||||
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
|
||||
@@ -66,15 +73,16 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
|
||||
raise
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]:
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]:
|
||||
"""Load the model from BentoML store.
|
||||
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
raise NotImplementedError("Currently work in progress.")
|
||||
|
||||
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
@@ -95,14 +103,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
|
||||
"Model does not have tokenizer. Make sure to save \
|
||||
the tokenizer within the model via 'custom_objects'.\
|
||||
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
bentomodel_fs.getsyspath("/"),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
**tokenizer_attrs,
|
||||
)
|
||||
return t.cast("TokenizerProtocol[_T]", tokenizer)
|
||||
return t.cast("TokenizerProtocol[T]", tokenizer)
|
||||
|
||||
|
||||
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any):
|
||||
|
||||
@@ -11,32 +11,43 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Serialisation related implementation for Transformers-based implementation.
|
||||
"""
|
||||
"""Serialisation related implementation for Transformers-based implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import openllm
|
||||
import typing as t
|
||||
import importlib
|
||||
import typing as t
|
||||
|
||||
import cloudpickle
|
||||
|
||||
import bentoml
|
||||
from bentoml._internal.frameworks.transformers import make_default_signatures
|
||||
from bentoml._internal.models.model import ModelOptions
|
||||
from ..exceptions import OpenLLMException
|
||||
import cloudpickle
|
||||
from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME
|
||||
from ..utils import LazyLoader, is_torch_available
|
||||
from ..utils import generate_context, normalize_attrs_to_model_tokenizer_pair
|
||||
from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, MODEL_TO_AUTOCLASS_MAPPING
|
||||
from bentoml._internal.models.model import ModelOptions
|
||||
|
||||
from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
|
||||
from .constants import MODEL_TO_AUTOCLASS_MAPPING
|
||||
from ..exceptions import OpenLLMException
|
||||
from ..utils import LazyLoader
|
||||
from ..utils import first_not_none
|
||||
from ..utils import generate_context
|
||||
from ..utils import generate_labels
|
||||
from ..utils import is_torch_available
|
||||
from ..utils import normalize_attrs_to_model_tokenizer_pair
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
import torch
|
||||
from .._types import P
|
||||
from .._llm import _M, _T
|
||||
from .._types import DictStrAny, ModelProtocol, TokenizerProtocol
|
||||
|
||||
import openllm
|
||||
import transformers
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from .._llm import M
|
||||
from .._llm import T
|
||||
from .._types import DictStrAny
|
||||
from .._types import ModelProtocol
|
||||
from .._types import TokenizerProtocol
|
||||
else:
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
@@ -46,7 +57,6 @@ def process_transformers_config(
|
||||
model_id: str, trust_remote_code: bool, **attrs: t.Any
|
||||
) -> tuple[transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]:
|
||||
"""Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs."""
|
||||
|
||||
config: transformers.PretrainedConfig = attrs.pop("config", None)
|
||||
|
||||
# this logic below is synonymous to handling `from_pretrained` attrs.
|
||||
@@ -123,7 +133,7 @@ def import_model(
|
||||
attrs = {**model_attrs, **attrs}
|
||||
|
||||
tokenizer = t.cast(
|
||||
transformers.PreTrainedTokenizer,
|
||||
"transformers.PreTrainedTokenizer",
|
||||
transformers.AutoTokenizer.from_pretrained(
|
||||
llm.model_id,
|
||||
config=config,
|
||||
@@ -133,13 +143,16 @@ def import_model(
|
||||
),
|
||||
)
|
||||
|
||||
model = infer_autoclass_from_llm_config(llm, config).from_pretrained(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
model = t.cast(
|
||||
"transformers.PreTrainedModel",
|
||||
infer_autoclass_from_llm_config(llm, config).from_pretrained(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -148,7 +161,7 @@ def import_model(
|
||||
module="openllm.serialisation.transformers",
|
||||
api_version="v1",
|
||||
context=generate_context(framework_name="openllm"),
|
||||
labels={"runtime": llm.runtime},
|
||||
labels=generate_labels(llm),
|
||||
options=ModelOptions(),
|
||||
signatures=make_default_signatures(model),
|
||||
external_modules=[
|
||||
@@ -173,6 +186,7 @@ def import_model(
|
||||
|
||||
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
|
||||
"""Return an instance of ``bentoml.Model`` from given LLM instance.
|
||||
|
||||
By default, it will try to check the model in the local store.
|
||||
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.
|
||||
|
||||
@@ -201,8 +215,9 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
|
||||
raise
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]:
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]:
|
||||
"""Load the model from BentoML store.
|
||||
|
||||
By default, it will try to find check the model in the local store.
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
"""
|
||||
@@ -224,11 +239,11 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
model = BetterTransformer.transform(model)
|
||||
return t.cast("ModelProtocol[_M]", model)
|
||||
model = BetterTransformer.transform(model) # type: ignore
|
||||
return t.cast("ModelProtocol[M]", model)
|
||||
|
||||
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
By default, it will try to find the bentomodel whether it is in store..
|
||||
@@ -249,14 +264,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]:
|
||||
"Model does not have tokenizer. Make sure to save \
|
||||
the tokenizer within the model via 'custom_objects'.\
|
||||
For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))"
|
||||
)
|
||||
) from None
|
||||
else:
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
bentomodel_fs.getsyspath("/"),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
**tokenizer_attrs,
|
||||
)
|
||||
return t.cast("TokenizerProtocol[_T]", tokenizer)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def save_pretrained(
|
||||
@@ -264,7 +279,7 @@ def save_pretrained(
|
||||
save_directory: str,
|
||||
is_main_process: bool = True,
|
||||
state_dict: DictStrAny | None = None,
|
||||
save_function: t.Callable[P, None] | None = None,
|
||||
save_function: t.Callable[..., None] | None = None,
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: int | str = "10GB",
|
||||
safe_serialization: bool = False,
|
||||
@@ -272,8 +287,7 @@ def save_pretrained(
|
||||
**attrs: t.Any,
|
||||
):
|
||||
"""Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``."""
|
||||
if save_function is None:
|
||||
save_function = torch.save
|
||||
save_function = first_not_none(save_function, default=torch.save)
|
||||
|
||||
model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
|
||||
|
||||
113
src/openllm/testing.py
Normal file
113
src/openllm/testing.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests utilities for OpenLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import typing as t
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ._types import LiteralRuntime
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(
|
||||
model: str,
|
||||
model_id: str | None = None,
|
||||
quantize: t.Literal["int4", "int8", "gptq"] | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
cleanup: bool = False,
|
||||
):
|
||||
logger.info("Building BentoML for %s", model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
|
||||
yield bento
|
||||
if cleanup:
|
||||
logger.info("Deleting %s", bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(
|
||||
bento: bentoml.Bento | str | bentoml.Tag,
|
||||
image_tag: str | None = None,
|
||||
cleanup: bool = False,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
if isinstance(bento, bentoml.Bento):
|
||||
bento_tag = bento.tag
|
||||
else:
|
||||
bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
|
||||
if image_tag is None:
|
||||
image_tag = str(bento_tag)
|
||||
|
||||
executable = shutil.which("docker")
|
||||
if not executable:
|
||||
raise RuntimeError("docker executable not found")
|
||||
|
||||
try:
|
||||
logger.info("Building container for %s", bento_tag)
|
||||
bentoml.container.build(
|
||||
bento_tag,
|
||||
backend="docker",
|
||||
image_tag=(image_tag,),
|
||||
progress="plain",
|
||||
**attrs,
|
||||
)
|
||||
yield image_tag
|
||||
finally:
|
||||
if cleanup:
|
||||
logger.info("Deleting container %s", image_tag)
|
||||
subprocess.check_output([executable, "rmi", "-f", image_tag])
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(
|
||||
model: str,
|
||||
model_id: str | None = None,
|
||||
implementation: LiteralRuntime = "pt",
|
||||
deployment_mode: t.Literal["container", "local"] = "local",
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = False,
|
||||
):
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
|
||||
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
|
||||
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}")
|
||||
|
||||
if not bentoml.list(bento_tag):
|
||||
bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
|
||||
else:
|
||||
bento = bentoml.get(bento_tag)
|
||||
|
||||
container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_")
|
||||
|
||||
if deployment_mode == "container":
|
||||
container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
|
||||
yield container_name
|
||||
if cleanup:
|
||||
clean_context.close()
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from .utils import is_flax_available
|
||||
from .utils import is_tf_available
|
||||
from .utils import is_torch_available
|
||||
|
||||
|
||||
try:
|
||||
import pytest
|
||||
except ImportError:
|
||||
raise ImportError("You need to install pytest to use 'openllm.tests' utilities: 'pip install pytest'")
|
||||
|
||||
|
||||
def require_tf(f: t.Callable[..., t.Any]):
|
||||
return pytest.mark.skipif(not is_tf_available(), reason="requires TensorFlow")(f)
|
||||
|
||||
|
||||
def require_flax(f: t.Callable[..., t.Any]):
|
||||
return pytest.mark.skipif(not is_flax_available(), reason="requires Flax")(f)
|
||||
|
||||
|
||||
def require_torch(f: t.Callable[..., t.Any]):
|
||||
return pytest.mark.skipif(not is_torch_available(), reason="requires PyTorch")(f)
|
||||
@@ -11,30 +11,35 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities function for OpenLLM. User can import these function for convenience, but
|
||||
"""Utilities function for OpenLLM.
|
||||
|
||||
User can import these function for convenience, but
|
||||
we won't ensure backward compatibility for these functions. So use with caution.
|
||||
"""
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import logging.config
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import types
|
||||
import typing as t
|
||||
from pathlib import Path
|
||||
|
||||
from circus.exc import ConflictError
|
||||
|
||||
from bentoml._internal.configuration import DEBUG_ENV_VAR as _DEBUG_ENV_VAR
|
||||
from bentoml._internal.configuration import GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR
|
||||
from bentoml._internal.configuration import get_debug_mode
|
||||
from bentoml._internal.configuration import get_quiet_mode
|
||||
from bentoml._internal.configuration import set_debug_mode
|
||||
from bentoml._internal.configuration import set_quiet_mode
|
||||
from bentoml._internal.log import configure_server_logging
|
||||
from bentoml._internal.models.model import ModelContext as _ModelContext
|
||||
from bentoml._internal.log import CLI_LOGGING_CONFIG as _CLI_LOGGING_CONFIG
|
||||
from bentoml._internal.types import LazyType
|
||||
from bentoml._internal.utils import LazyLoader
|
||||
from bentoml._internal.utils import bentoml_cattr
|
||||
from bentoml._internal.utils import cached_contextmanager
|
||||
from bentoml._internal.utils import copy_file_to_fs_folder
|
||||
from bentoml._internal.utils import first_not_none
|
||||
from bentoml._internal.utils import pkg
|
||||
@@ -62,8 +67,26 @@ else:
|
||||
types.UnionType,
|
||||
)
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload as _overload
|
||||
else:
|
||||
from typing_extensions import overload as _overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
|
||||
from .._types import DictStrAny
|
||||
from .._types import LiteralRuntime
|
||||
from .._types import P
|
||||
from ..models.auto.factory import BaseAutoLLMClass
|
||||
|
||||
|
||||
def set_debug_mode(enabled: bool):
|
||||
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
|
||||
os.environ[_DEBUG_ENV_VAR] = str(enabled)
|
||||
os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
|
||||
|
||||
|
||||
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
|
||||
@@ -75,16 +98,12 @@ def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.An
|
||||
raise
|
||||
|
||||
|
||||
def gpu_count() -> tuple[int, ...]:
|
||||
def gpu_count() -> tuple[str, ...]:
|
||||
from bentoml._internal.resource import NvidiaGpuResource
|
||||
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices is not None:
|
||||
if "," in cuda_visible_devices:
|
||||
available_gpu = tuple(int(i) for i in cuda_visible_devices.split(","))
|
||||
else:
|
||||
available_gpu = tuple(int(i) for i in cuda_visible_devices.split())
|
||||
return available_gpu
|
||||
return tuple(i for i in cuda_visible_devices.split(","))
|
||||
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
@@ -94,7 +113,7 @@ _object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
|
||||
"""This makes sure that we don't overwrite any existing attributes on the object"""
|
||||
"""This makes sure that we don't overwrite any existing attributes on the object."""
|
||||
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
|
||||
|
||||
if not hasattr(obj, name):
|
||||
@@ -107,22 +126,65 @@ def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None
|
||||
|
||||
DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get("OPENLLMDEVDEBUG")))
|
||||
|
||||
SHOW_CODEGEN = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3
|
||||
|
||||
_LOGGING_CONFIG = _CLI_LOGGING_CONFIG.copy()
|
||||
_LOGGING_CONFIG["loggers"].update(
|
||||
{
|
||||
"openllm": {
|
||||
"level": logging.INFO,
|
||||
|
||||
class _ExceptionFilter(logging.Filter):
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
if exclude_exceptions is None:
|
||||
exclude_exceptions = [ConflictError]
|
||||
else:
|
||||
exclude_exceptions.append(ConflictError)
|
||||
|
||||
super(_ExceptionFilter, self).__init__(**kwargs)
|
||||
self.EXCLUDE_EXCEPTIONS = exclude_exceptions
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.exc_info:
|
||||
etype, _, _ = record.exc_info
|
||||
if etype is not None:
|
||||
for exc in self.EXCLUDE_EXCEPTIONS:
|
||||
if issubclass(etype, exc):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
_LOGGING_CONFIG: DictStrAny = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": True,
|
||||
"filters": {"excfilter": {"()": _ExceptionFilter}},
|
||||
"handlers": {
|
||||
"bentomlhandler": {
|
||||
"class": "logging.StreamHandler",
|
||||
"filters": ["excfilter"],
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"defaulthandler": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": logging.WARNING,
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"bentoml": {
|
||||
"handlers": ["bentomlhandler", "defaulthandler"],
|
||||
"level": logging.INFO,
|
||||
"propagate": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
"openllm": {
|
||||
"handlers": ["bentomlhandler", "defaulthandler"],
|
||||
"level": logging.INFO,
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
"root": {"level": logging.WARNING},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging() -> None:
|
||||
"""Configure logging for OpenLLM. Behaves similar to how BentoML loggers
|
||||
are being configured."""
|
||||
"""Configure logging for OpenLLM.
|
||||
|
||||
Behaves similar to how BentoML loggers are being configured.
|
||||
"""
|
||||
if get_quiet_mode():
|
||||
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
|
||||
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
|
||||
@@ -144,7 +206,7 @@ def in_notebook() -> bool:
|
||||
try:
|
||||
from IPython.core.getipython import get_ipython
|
||||
|
||||
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
||||
if "IPKernelApp" not in get_ipython().config: # type: ignore
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
@@ -153,8 +215,91 @@ def in_notebook() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
_dockerenv = Path("/.dockerenv")
|
||||
_cgroup = Path("/proc/self/cgroup")
|
||||
|
||||
|
||||
class suppress(contextlib.suppress, contextlib.ContextDecorator):
|
||||
"""A version of contextlib.suppress with decorator support.
|
||||
|
||||
>>> @suppress(KeyError)
|
||||
... def key_error():
|
||||
... {}['']
|
||||
>>> key_error()
|
||||
"""
|
||||
|
||||
|
||||
def compose(*funcs: t.Callable[..., t.Any]):
|
||||
"""Compose any number of unary functions into a single unary function.
|
||||
|
||||
>>> import textwrap
|
||||
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
|
||||
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
|
||||
>>> strip_and_dedent(compose.__doc__) == expected
|
||||
True
|
||||
|
||||
Compose also allows the innermost function to take arbitrary arguments.
|
||||
|
||||
>>> round_three = lambda x: round(x, ndigits=3)
|
||||
>>> f = compose(round_three, int.__truediv__)
|
||||
>>> [f(3*x, x+1) for x in range(1,10)]
|
||||
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
|
||||
"""
|
||||
|
||||
def compose_two(f1: t.Callable[..., t.Any], f2: t.Callable[P, t.Any]):
|
||||
def _(*args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
return f1(f2(*args, **kwargs))
|
||||
|
||||
return _
|
||||
|
||||
return functools.reduce(compose_two, funcs)
|
||||
|
||||
|
||||
def apply(transform: t.Callable[..., t.Any]):
|
||||
"""Decorate a function with a transform function that is invoked on results returned from the decorated function.
|
||||
|
||||
```python
|
||||
@apply(reversed)
|
||||
def get_numbers(start):
|
||||
"doc for get_numbers"
|
||||
return range(start, start+3)
|
||||
list(get_numbers(4))
|
||||
# [6, 5, 4]
|
||||
```
|
||||
```python
|
||||
get_numbers.__doc__
|
||||
# 'doc for get_numbers'
|
||||
```
|
||||
"""
|
||||
|
||||
def wrap(func: t.Callable[P, t.Any]):
|
||||
return functools.wraps(func)(compose(transform, func))
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
@apply(bool)
|
||||
@suppress(FileNotFoundError)
|
||||
def _text_in_file(text: str, filename: Path):
|
||||
return any(text in line for line in filename.open())
|
||||
|
||||
|
||||
def in_docker() -> bool:
|
||||
"""Is this current environment running in docker?
|
||||
|
||||
```python
|
||||
type(in_docker())
|
||||
```
|
||||
"""
|
||||
return _dockerenv.exists() or _text_in_file("docker", _cgroup)
|
||||
|
||||
|
||||
T = t.TypeVar("T")
|
||||
K = t.TypeVar("K")
|
||||
|
||||
|
||||
def resolve_filepath(path: str) -> str:
|
||||
"""Resolve a file path to an absolute path, expand user and environment variables"""
|
||||
"""Resolve a file path to an absolute path, expand user and environment variables."""
|
||||
try:
|
||||
return resolve_user_filepath(path, None)
|
||||
except FileNotFoundError:
|
||||
@@ -166,7 +311,9 @@ def validate_is_path(maybe_path: str) -> bool:
|
||||
|
||||
|
||||
def generate_context(framework_name: str) -> _ModelContext:
|
||||
from .import_utils import is_torch_available, is_flax_available, is_tf_available
|
||||
from .import_utils import is_flax_available
|
||||
from .import_utils import is_tf_available
|
||||
from .import_utils import is_torch_available
|
||||
|
||||
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
|
||||
if is_torch_available():
|
||||
@@ -174,7 +321,7 @@ def generate_context(framework_name: str) -> _ModelContext:
|
||||
if is_tf_available():
|
||||
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
|
||||
|
||||
framework_versions["tensorflow-macos" if platform.system() == "Darwin" else "tensorflow"] = get_tf_version()
|
||||
framework_versions["tensorflow"] = get_tf_version()
|
||||
if is_flax_available():
|
||||
framework_versions.update(
|
||||
{
|
||||
@@ -186,6 +333,10 @@ def generate_context(framework_name: str) -> _ModelContext:
|
||||
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
|
||||
|
||||
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> DictStrAny:
|
||||
return {"runtime": llm.runtime, "framework": "openllm"}
|
||||
|
||||
|
||||
_TOKENIZER_PREFIX = "_tokenizer_"
|
||||
|
||||
|
||||
@@ -198,6 +349,33 @@ def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[DictStrAny,
|
||||
return attrs, tokenizer_attrs
|
||||
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]:
|
||||
...
|
||||
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]:
|
||||
...
|
||||
|
||||
|
||||
@_overload
|
||||
def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]:
|
||||
...
|
||||
|
||||
|
||||
def infer_auto_class(implementation: LiteralRuntime) -> type[BaseAutoLLMClass]:
|
||||
if implementation == "tf":
|
||||
from ..models.auto import AutoTFLLM as auto
|
||||
elif implementation == "flax":
|
||||
from ..models.auto import AutoFlaxLLM as auto
|
||||
elif implementation == "pt":
|
||||
from ..models.auto import AutoLLM as auto
|
||||
else:
|
||||
raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf')")
|
||||
return auto
|
||||
|
||||
|
||||
# NOTE: The set marks contains a set of modules name
|
||||
# that are available above and are whitelisted
|
||||
# to be included in the extra_objects map.
|
||||
@@ -218,6 +396,7 @@ _import_structure = {
|
||||
"codegen": [],
|
||||
"dantic": [],
|
||||
"representation": ["ReprMixin"],
|
||||
"lazy": ["LazyModule"],
|
||||
"import_utils": [
|
||||
"OPTIONAL_DEPENDENCIES",
|
||||
"ENV_VARS_TRUE_VALUES",
|
||||
@@ -248,28 +427,18 @@ if t.TYPE_CHECKING:
|
||||
from . import LazyType as LazyType
|
||||
from . import analytics as analytics
|
||||
from . import bentoml_cattr as bentoml_cattr
|
||||
from . import cached_contextmanager as cached_contextmanager
|
||||
from . import codegen as codegen
|
||||
from . import configure_logging as configure_logging
|
||||
from . import configure_server_logging as configure_server_logging
|
||||
from . import copy_file_to_fs_folder as copy_file_to_fs_folder
|
||||
from . import dantic as dantic
|
||||
from . import first_not_none as first_not_none
|
||||
from . import get_debug_mode as get_debug_mode
|
||||
from . import get_quiet_mode as get_quiet_mode
|
||||
from . import gpu_count as gpu_count
|
||||
from . import lenient_issubclass as lenient_issubclass
|
||||
from . import non_intrusive_setattr as non_intrusive_setattr
|
||||
from . import pkg as pkg
|
||||
from . import reserve_free_port as reserve_free_port
|
||||
from . import resolve_user_filepath as resolve_user_filepath
|
||||
from . import set_debug_mode as set_debug_mode
|
||||
from . import set_quiet_mode as set_quiet_mode
|
||||
from . import in_notebook as in_notebook
|
||||
from . import validate_or_create_dir as validate_or_create_dir
|
||||
from . import validate_is_path as validate_is_path
|
||||
from . import resolve_filepath as resolve_filepath
|
||||
from . import normalize_attrs_to_model_tokenizer_pair as normalize_attrs_to_model_tokenizer_pair
|
||||
from . import generate_context as generate_context
|
||||
from . import field_env_key as field_env_key
|
||||
from . import validate_or_create_dir as validate_or_create_dir
|
||||
from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES
|
||||
from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES
|
||||
from .import_utils import DummyMetaclass as DummyMetaclass
|
||||
@@ -279,6 +448,9 @@ if t.TYPE_CHECKING:
|
||||
from .import_utils import is_datasets_available as is_datasets_available
|
||||
from .import_utils import is_einops_available as is_einops_available
|
||||
from .import_utils import is_flax_available as is_flax_available
|
||||
from .import_utils import is_jupyter_available as is_jupyter_available
|
||||
from .import_utils import is_jupytext_available as is_jupytext_available
|
||||
from .import_utils import is_notebook_available as is_notebook_available
|
||||
from .import_utils import is_peft_available as is_peft_available
|
||||
from .import_utils import is_tf_available as is_tf_available
|
||||
from .import_utils import is_torch_available as is_torch_available
|
||||
@@ -287,10 +459,6 @@ if t.TYPE_CHECKING:
|
||||
from .import_utils import is_triton_available as is_triton_available
|
||||
from .import_utils import require_backends as require_backends
|
||||
from .import_utils import requires_dependencies as requires_dependencies
|
||||
from .import_utils import is_jupyter_available as is_jupyter_available
|
||||
from .import_utils import is_jupytext_available as is_jupytext_available
|
||||
from .import_utils import is_notebook_available as is_notebook_available
|
||||
from .lazy import LazyModule as LazyModule
|
||||
from .representation import ReprMixin as ReprMixin
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -11,13 +11,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Telemetry related for OpenLLM tracking.
|
||||
"""Telemetry related for OpenLLM tracking.
|
||||
|
||||
Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import os
|
||||
|
||||
@@ -13,11 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import linecache
|
||||
import logging
|
||||
import os
|
||||
import string
|
||||
import types
|
||||
import typing as t
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
@@ -28,17 +31,16 @@ if t.TYPE_CHECKING:
|
||||
|
||||
import openllm
|
||||
|
||||
DictStrAny = dict[str, t.Any]
|
||||
ListStr = list[str]
|
||||
from .._types import AnyCallable
|
||||
from .._types import DictStrAny
|
||||
from .._types import ListStr
|
||||
from .._types import P
|
||||
|
||||
from attr import _make_method
|
||||
PartialAny = functools.partial[t.Any]
|
||||
else:
|
||||
# NOTE: Using internal API from attr here, since we are actually
|
||||
# allowing subclass of openllm.LLMConfig to become 'attrs'-ish
|
||||
from attr._make import _make_method
|
||||
|
||||
DictStrAny = dict
|
||||
ListStr = list
|
||||
PartialAny = functools.partial
|
||||
|
||||
_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
|
||||
|
||||
@@ -53,11 +55,12 @@ class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: t.LiteralString = "__model_name__"
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
"""The formatter that extends model_name to be formatted the 'service.py'."""
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
|
||||
def vformat(self, format_string: str) -> str:
|
||||
return super().vformat(format_string, (), {self.model_keyword: self.model_name})
|
||||
def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.LiteralString:
|
||||
return t.cast("t.LiteralString", super().vformat(format_string, (), {self.model_keyword: self.model_name}))
|
||||
|
||||
def can_format(self, value: str) -> bool:
|
||||
try:
|
||||
@@ -117,9 +120,7 @@ _sentinel = object()
|
||||
|
||||
|
||||
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any):
|
||||
"""
|
||||
Check whether *cls* defines *attrib_name* (and doesn't just inherit it).
|
||||
"""
|
||||
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
|
||||
attr = getattr(cls, attrib_name, _sentinel)
|
||||
if attr is _sentinel:
|
||||
return False
|
||||
@@ -133,9 +134,7 @@ def has_own_attribute(cls: type[t.Any], attrib_name: t.Any):
|
||||
|
||||
|
||||
def get_annotations(cls: type[t.Any]) -> DictStrAny:
|
||||
"""
|
||||
Get annotations for *cls*.
|
||||
"""
|
||||
"""Get annotations for *cls*."""
|
||||
if has_own_attribute(cls, "__annotations__"):
|
||||
return cls.__annotations__
|
||||
|
||||
@@ -151,8 +150,7 @@ _classvar_prefixes = (
|
||||
|
||||
|
||||
def is_class_var(annot: str | t.Any) -> bool:
|
||||
"""
|
||||
Check whether *annot* is a typing.ClassVar.
|
||||
"""Check whether *annot* is a typing.ClassVar.
|
||||
|
||||
The string comparison hack is used to avoid evaluating all string
|
||||
annotations which would put attrs-based classes at a performance
|
||||
@@ -168,16 +166,14 @@ def is_class_var(annot: str | t.Any) -> bool:
|
||||
|
||||
|
||||
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
|
||||
"""
|
||||
Add __module__ and __qualname__ to a *method* if possible.
|
||||
"""
|
||||
"""Add __module__ and __qualname__ to a *method* if possible."""
|
||||
try:
|
||||
method_or_cls.__module__ = cls.__module__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
method_or_cls.__qualname__ = ".".join((cls.__qualname__, method_or_cls.__name__))
|
||||
method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -191,6 +187,64 @@ def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str
|
||||
return method_or_cls
|
||||
|
||||
|
||||
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = ""):
|
||||
"""Exec the script with the given global (globs) and local (locs) variables."""
|
||||
bytecode = compile(script, filename, "exec")
|
||||
eval(bytecode, globs, locs) # noqa: S307
|
||||
|
||||
|
||||
# ported from attrs
|
||||
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
|
||||
"""Create the method with the script given and return the method object."""
|
||||
locs: DictStrAny = {}
|
||||
|
||||
# In order of debuggers like PDB being able to step through the code,
|
||||
# we add a fake linecache entry.
|
||||
count = 1
|
||||
base_filename = filename
|
||||
while True:
|
||||
linecache_tuple = (
|
||||
len(script),
|
||||
None,
|
||||
script.splitlines(True),
|
||||
filename,
|
||||
)
|
||||
old_val = linecache.cache.setdefault(filename, linecache_tuple)
|
||||
if old_val == linecache_tuple:
|
||||
break
|
||||
else:
|
||||
filename = f"{base_filename[:-1]}-{count}>"
|
||||
count += 1
|
||||
|
||||
_compile_and_eval(script, globs, locs, filename)
|
||||
|
||||
return locs[name]
|
||||
|
||||
|
||||
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]):
|
||||
"""Create a tuple subclass to hold class attributes.
|
||||
|
||||
The subclass is a bare tuple with properties for names.
|
||||
|
||||
class MyClassAttributes(tuple):
|
||||
__slots__ = ()
|
||||
x = property(itemgetter(0))
|
||||
"""
|
||||
attr_class_name = f"{cls_name}Attributes"
|
||||
attr_class_template = [
|
||||
f"class {attr_class_name}(tuple):",
|
||||
" __slots__ = ()",
|
||||
]
|
||||
if attr_names:
|
||||
for i, attr_name in enumerate(attr_names):
|
||||
attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
|
||||
else:
|
||||
attr_class_template.append(" pass")
|
||||
globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
|
||||
_compile_and_eval("\n".join(attr_class_template), globs)
|
||||
return globs[attr_class_name]
|
||||
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str):
|
||||
return f"<{cls.__name__} generated {func_name} {cls.__module__}." f"{getattr(cls, '__qualname__', cls.__name__)}>"
|
||||
|
||||
@@ -203,7 +257,7 @@ def generate_function(
|
||||
globs: dict[str, t.Any],
|
||||
annotations: dict[str, t.Any] | None = None,
|
||||
):
|
||||
from . import DEBUG
|
||||
from . import SHOW_CODEGEN
|
||||
|
||||
script = "def %s(%s):\n %s\n" % (
|
||||
func_name,
|
||||
@@ -214,7 +268,7 @@ def generate_function(
|
||||
if annotations:
|
||||
meth.__annotations__ = annotations
|
||||
|
||||
if DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3:
|
||||
if SHOW_CODEGEN:
|
||||
logger.info("Generated script for %s:\n\n%s", typ, script)
|
||||
|
||||
return meth
|
||||
@@ -227,7 +281,8 @@ def make_env_transformer(
|
||||
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
|
||||
globs: DictStrAny | None = None,
|
||||
):
|
||||
from . import dantic, field_env_key
|
||||
from . import dantic
|
||||
from . import field_env_key
|
||||
|
||||
def identity(_: str, x_value: t.Any) -> t.Any:
|
||||
return x_value
|
||||
@@ -268,3 +323,35 @@ def make_env_transformer(
|
||||
globs=globs,
|
||||
annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann},
|
||||
)
|
||||
|
||||
|
||||
def gen_sdk(func: t.Callable[P, t.Any], name: str | None = None, **attrs: t.Any):
|
||||
from .representation import ReprMixin
|
||||
|
||||
if name is None:
|
||||
name = func.__name__.strip("_")
|
||||
|
||||
_signatures = inspect.signature(func).parameters
|
||||
|
||||
def _repr(self: ReprMixin) -> str:
|
||||
return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
|
||||
|
||||
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
|
||||
return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
|
||||
|
||||
return functools.update_wrapper(
|
||||
types.new_class(
|
||||
name,
|
||||
(PartialAny, ReprMixin),
|
||||
exec_body=lambda ns: ns.update(
|
||||
{
|
||||
"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]),
|
||||
"__repr_args__": _repr_args,
|
||||
"__repr__": _repr,
|
||||
"__doc__": inspect.cleandoc(t.cast(str, func.__doc__)),
|
||||
"__module__": "openllm",
|
||||
}
|
||||
),
|
||||
)(func, **attrs),
|
||||
func,
|
||||
)
|
||||
|
||||
@@ -14,71 +14,40 @@
|
||||
"""A shim provides usable transition from pydantic to attrs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from enum import Enum
|
||||
|
||||
import attr
|
||||
import click
|
||||
import sys
|
||||
import click_option_group as cog
|
||||
import inflection
|
||||
import orjson
|
||||
from click import ParamType, shell_completion as sc, types as click_types
|
||||
from click import ParamType
|
||||
from click import shell_completion as sc
|
||||
from click import types as click_types
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import _ValidatorType
|
||||
|
||||
from .._types import ClickFunctionWrapper
|
||||
from .._types import F
|
||||
from .._types import O_co
|
||||
from .._types import P
|
||||
from .._types import ListAny
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
@overload
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[..., openllm.LLMConfig]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def attrs_to_options( # type: ignore (overlapping overload)
|
||||
name: str,
|
||||
field: attr.Attribute[O_co],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[P, O_co]]:
|
||||
...
|
||||
|
||||
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]:
|
||||
):
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata["env"]
|
||||
dasherized = inflection.dasherize(name)
|
||||
@@ -86,6 +55,8 @@ def attrs_to_options(
|
||||
|
||||
if typ in (None, attr.NOTHING):
|
||||
typ = field.type
|
||||
if typ is None:
|
||||
raise RuntimeError(f"Failed to parse type for {name}")
|
||||
|
||||
full_option_name = f"--{dasherized}"
|
||||
if field.type is bool:
|
||||
@@ -116,7 +87,7 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
try:
|
||||
return orjson.loads(value.lower())
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}")
|
||||
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
return value
|
||||
|
||||
|
||||
@@ -132,14 +103,16 @@ def Field(
|
||||
use_default_converter: bool = True,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same
|
||||
interface as pydantic's Field.
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
|
||||
|
||||
By default, if both validator and ge are provided, then then ge will be
|
||||
piped into first, then all of the other validator will be run afterwards.
|
||||
|
||||
Args:
|
||||
default: The default value for ``dantic.Field``. Defaults to ``None``.
|
||||
ge: Greater than or equal to. Defaults to None.
|
||||
le: Less than or equal to. Defaults to None.
|
||||
validator: Optional attrs-compatible validators type. Default to None
|
||||
description: the documentation for the field. Defaults to None.
|
||||
env: the environment variable to read from. Defaults to None.
|
||||
auto_default: a bool indicating whether to use the default value as the environment.
|
||||
@@ -150,7 +123,7 @@ def Field(
|
||||
to True. If set to False, then the default converter will not be used.
|
||||
The default converter converts a given value from the environment variable
|
||||
for this given Field.
|
||||
**kwargs: The rest of the arguments are passed to attr.field
|
||||
**attrs: The rest of the arguments are passed to attr.field
|
||||
"""
|
||||
metadata = attrs.pop("metadata", {})
|
||||
if description is None:
|
||||
@@ -205,13 +178,14 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]:
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
|
||||
assert t.get_origin(field_type) is not t.Union, "Unions are not supported"
|
||||
if t.get_origin(field_type) is t.Union:
|
||||
raise NotImplementedError("Unions are not supported")
|
||||
# enumeration strings or other Enum derivatives
|
||||
if lenient_issubclass(field_type, Enum):
|
||||
return EnumChoice(enum=field_type, case_sensitive=True)
|
||||
# literals are enum-like with way less functionality
|
||||
if is_literal(field_type):
|
||||
return LiteralChoice(enum=field_type, case_sensitive=True)
|
||||
return LiteralChoice(value=field_type, case_sensitive=True)
|
||||
# modules, classes, functions
|
||||
if is_typing(field_type):
|
||||
return ModuleType()
|
||||
@@ -248,6 +222,7 @@ def is_typing(field_type: type) -> bool:
|
||||
|
||||
def is_literal(field_type: type) -> bool:
|
||||
"""Checks whether the given field type is a Literal type or not.
|
||||
|
||||
Literals are weird: isinstance and subclass do not work, so you compare
|
||||
the origin with the Literal declaration itself.
|
||||
|
||||
@@ -266,63 +241,76 @@ class ModuleType(ParamType):
|
||||
|
||||
def _import_object(self, value: str) -> t.Any:
|
||||
module_name, class_name = value.rsplit(".", maxsplit=1)
|
||||
assert all(s.isidentifier() for s in module_name.split(".")), f"'{value}' is not a valid module name"
|
||||
assert class_name.isidentifier(), f"Variable '{class_name}' is not a valid identifier"
|
||||
if not all(s.isidentifier() for s in module_name.split(".")):
|
||||
raise ValueError(f"'{value}' is not a valid module name")
|
||||
if not class_name.isidentifier():
|
||||
raise ValueError(f"Variable '{class_name}' is not a valid identifier")
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
if class_name:
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError:
|
||||
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.")
|
||||
raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None
|
||||
return None
|
||||
|
||||
def convert(self, value: str, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
return self._import_object(value)
|
||||
return value
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid object ({type(exc)}: {str(exc)})", param, ctx)
|
||||
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
|
||||
|
||||
|
||||
class EnumChoice(click.Choice):
|
||||
name = "enum"
|
||||
|
||||
def __init__(self, enum: Enum, case_sensitive: bool = False):
|
||||
"""Enum type support for click that extends ``click.Choice``.
|
||||
|
||||
Args:
|
||||
enum: Given enum
|
||||
case_sensitive: Whether this choice should be case case_sensitive.
|
||||
"""
|
||||
self.mapping = enum
|
||||
self.internal_type = enum
|
||||
super().__init__([e.name for e in self.mapping], case_sensitive)
|
||||
self.internal_type = type(enum)
|
||||
choices: ListAny = [e.name for e in enum.__class__]
|
||||
super().__init__(choices, case_sensitive)
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum:
|
||||
if isinstance(value, self.internal_type):
|
||||
return value
|
||||
result = super().convert(value, param, ctx)
|
||||
if isinstance(result, str):
|
||||
result = self.mapping[result]
|
||||
result = self.internal_type[result]
|
||||
return result
|
||||
|
||||
|
||||
class LiteralChoice(EnumChoice):
|
||||
name = "literal"
|
||||
|
||||
def __init__(self, enum: t.LiteralString, case_sensitive: bool = False):
|
||||
def __init__(self, value: t.Any, case_sensitive: bool = False):
|
||||
"""Literal support for click."""
|
||||
# expect every literal value to belong to the same primitive type
|
||||
values = list(enum.__args__)
|
||||
values = list(value.__args__)
|
||||
item_type = type(values[0])
|
||||
assert all(isinstance(v, item_type) for v in values), f"Field {enum} contains items of different types"
|
||||
if not all(isinstance(v, item_type) for v in values):
|
||||
raise ValueError(f"Field {value} contains items of different types.")
|
||||
self.internal_type = item_type
|
||||
self.mapping = {str(v): v for v in values}
|
||||
super(EnumChoice, self).__init__(list(self.mapping.keys()), case_sensitive)
|
||||
|
||||
|
||||
def allows_multiple(field_type: t.Any) -> bool:
|
||||
def allows_multiple(field_type: type) -> bool:
|
||||
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
|
||||
|
||||
For containers, it exploits click's support for lists and such to use the same option multiple times
|
||||
to create a complex object: `python run.py --subsets train --subsets test`
|
||||
# becomes `subsets: ["train", "test"]`.
|
||||
|
||||
Args:
|
||||
field_type (type): pydantic type
|
||||
field_type: pydantic type.
|
||||
|
||||
Returns:
|
||||
bool: true if it's a composite field (lists, containers and so on), false otherwise
|
||||
@@ -360,8 +348,7 @@ def is_mapping(field_type: type) -> bool:
|
||||
|
||||
|
||||
def is_container(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a container type ('contains' other types), like
|
||||
lists and tuples.
|
||||
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
|
||||
|
||||
Args:
|
||||
field_type: pydantic field type
|
||||
@@ -391,12 +378,13 @@ def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType
|
||||
Returns:
|
||||
ParamType | tuple[ParamType]: single click-compatible type or a tuple
|
||||
"""
|
||||
assert is_container(field_type), "Field type is not a container"
|
||||
if not is_container(field_type):
|
||||
raise ValueError("Field type is not a container type.")
|
||||
args = t.get_args(field_type)
|
||||
# Early out for untyped containers: standard lists, tuples, List[Any]
|
||||
# Use strings when the type is unknown, avoid click's type guessing
|
||||
if len(args) == 0:
|
||||
return str
|
||||
return click_types.convert_type(str)
|
||||
# Early out for homogenous containers: Tuple[int], List[str]
|
||||
if len(args) == 1:
|
||||
return parse_single_arg(args[0])
|
||||
@@ -409,6 +397,7 @@ def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType
|
||||
|
||||
def parse_single_arg(arg: type) -> ParamType:
|
||||
"""Returns the click-compatible type for container origin types.
|
||||
|
||||
In this case, returns string when it's not inferrable, a JSON for mappings
|
||||
and the original type itself in every other case (ints, floats and so on).
|
||||
Bytes is a special case, not natively handled by click.
|
||||
@@ -421,13 +410,13 @@ def parse_single_arg(arg: type) -> ParamType:
|
||||
"""
|
||||
# When we don't know the type, we choose 'str'
|
||||
if arg is t.Any:
|
||||
return str
|
||||
return click_types.convert_type(str)
|
||||
# For containers and nested models, we use JSON
|
||||
if is_container(arg):
|
||||
return JsonType()
|
||||
if openllm.utils.lenient_issubclass(arg, bytes):
|
||||
return BytesType()
|
||||
return arg
|
||||
return click_types.convert_type(arg)
|
||||
|
||||
|
||||
class BytesType(ParamType):
|
||||
@@ -439,7 +428,7 @@ class BytesType(ParamType):
|
||||
try:
|
||||
return str.encode(value)
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid string ({str(exc)})", param, ctx)
|
||||
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
|
||||
|
||||
|
||||
CYGWIN = sys.platform.startswith("cygwin")
|
||||
@@ -470,17 +459,14 @@ class CudaValueType(ParamType):
|
||||
return var
|
||||
|
||||
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
"""Return a list of
|
||||
:class:`~click.shell_completion.CompletionItem` objects for the
|
||||
incomplete value. Most types do not provide completions, but
|
||||
some do, and this allows custom types to provide custom
|
||||
completions as well.
|
||||
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
|
||||
|
||||
:param ctx: Invocation context for this command.
|
||||
:param param: The parameter that is requesting completion.
|
||||
:param incomplete: Value being completed. May be empty.
|
||||
Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.
|
||||
|
||||
.. versionadded:: 8.0
|
||||
Args:
|
||||
ctx: Invocation context for this command.
|
||||
param: The parameter that is requesting completion.
|
||||
incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
from ..utils import gpu_count
|
||||
|
||||
@@ -506,6 +492,7 @@ class CudaValueType(ParamType):
|
||||
return tuple(self.typ(x, param, ctx) for x in value.split(","))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""CUDA is a click.STRING extension."""
|
||||
return "STRING"
|
||||
|
||||
|
||||
@@ -516,6 +503,11 @@ class JsonType(ParamType):
|
||||
name = "json"
|
||||
|
||||
def __init__(self, should_load: bool = True) -> None:
|
||||
"""Support JSON type for click.ParamType.
|
||||
|
||||
Args:
|
||||
should_load: Whether to load the JSON. Default to True. If False, the value won't be converted.
|
||||
"""
|
||||
super().__init__()
|
||||
self.should_load = should_load
|
||||
|
||||
@@ -525,4 +517,4 @@ class JsonType(ParamType):
|
||||
try:
|
||||
return orjson.loads(value)
|
||||
except orjson.JSONDecodeError as exc:
|
||||
self.fail(f"'{value}' is not a valid JSON string ({str(exc)})", param, ctx)
|
||||
self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ..utils import DummyMetaclass
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ..utils import DummyMetaclass
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ..utils import DummyMetaclass
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ..utils import DummyMetaclass
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ..utils import DummyMetaclass
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
from ..utils import DummyMetaclass
|
||||
|
||||
@@ -12,17 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons.
|
||||
"""
|
||||
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from abc import ABCMeta
|
||||
from collections import OrderedDict
|
||||
@@ -38,15 +36,21 @@ from .representation import ReprMixin
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]]
|
||||
from .._types import LiteralRuntime
|
||||
from .._types import P
|
||||
|
||||
class _AnnotatedLazyLoader(LazyLoader):
|
||||
DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[..., t.LiteralString]
|
||||
|
||||
else:
|
||||
_AnnotatedLazyLoader = LazyLoader
|
||||
BackendOrderredDict = OrderedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -188,7 +192,7 @@ def is_tf_available():
|
||||
_tf_available = _tf_version is not None
|
||||
if _tf_available:
|
||||
if _tf_version and version.parse(_tf_version) < version.parse("2"):
|
||||
logger.info(f"TensorFlow found but with version {_tf_version}. OpenLLM only supports TF 2.x")
|
||||
logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
|
||||
_tf_available = False
|
||||
else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
@@ -321,8 +325,9 @@ BACKENDS_MAPPING = BackendOrderredDict(
|
||||
|
||||
|
||||
class DummyMetaclass(ABCMeta):
|
||||
"""Metaclass for dummy object. It will raises ImportError
|
||||
generated by ``require_backends`` if users try to access attributes from given class
|
||||
"""Metaclass for dummy object.
|
||||
|
||||
It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
|
||||
"""
|
||||
|
||||
_backends: t.List[str]
|
||||
@@ -368,7 +373,7 @@ class EnvVarMixin(ReprMixin):
|
||||
bettertransformer: str
|
||||
runtime: t.Literal["ggml", "transformers"]
|
||||
|
||||
framework_value: t.Literal["pt", "tf", "flax"]
|
||||
framework_value: LiteralRuntime
|
||||
quantize_value: str | None
|
||||
bettertransformer_value: str | None
|
||||
runtime_value: t.Literal["ggml", "transformers"]
|
||||
@@ -385,17 +390,17 @@ class EnvVarMixin(ReprMixin):
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['runtime']) -> str: ...
|
||||
def __getitem__(self, item: t.Literal["runtime"]) -> str: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['framework_value']) -> t.Literal['pt', 'tf', 'flax']: ...
|
||||
def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['quantize_value']) -> str | None: ...
|
||||
def __getitem__(self, item: t.Literal["quantize_value"]) -> str | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['model_id_value']) -> str | None: ...
|
||||
def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['bettertransformer_value']) -> str | None: ...
|
||||
def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> str | None: ...
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['runtime_value']) -> t.Literal['ggml', 'transformers']: ...
|
||||
def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ...
|
||||
# fmt: on
|
||||
def __getitem__(self, item: str | t.Any) -> t.Any:
|
||||
if hasattr(self, item):
|
||||
@@ -409,8 +414,8 @@ class EnvVarMixin(ReprMixin):
|
||||
quantize: t.LiteralString | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
):
|
||||
from .._configuration import field_env_key
|
||||
from . import codegen
|
||||
from .._configuration import field_env_key
|
||||
|
||||
model_name = inflection.underscore(model_name)
|
||||
|
||||
@@ -464,5 +469,5 @@ class EnvVarMixin(ReprMixin):
|
||||
return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
|
||||
|
||||
@property
|
||||
def module(self) -> LazyLoader:
|
||||
return LazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
|
||||
def module(self):
|
||||
return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import importlib.machinery
|
||||
import itertools
|
||||
@@ -40,10 +39,9 @@ _reserved_namespace = {"__openllm_special__", "__openllm_migration__"}
|
||||
|
||||
|
||||
class LazyModule(types.ModuleType):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
This is a direct port from transformers.utils.import_utils._LazyModule for
|
||||
backwards compatibility with transformers < 4.18
|
||||
"""Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
|
||||
This is a direct port from transformers.utils.import_utils._LazyModule for backwards compatibility with transformers < 4.18.
|
||||
|
||||
This is an extension a more powerful LazyLoader.
|
||||
"""
|
||||
@@ -56,8 +54,22 @@ class LazyModule(types.ModuleType):
|
||||
module_file: str,
|
||||
import_structure: dict[str, list[str]],
|
||||
module_spec: importlib.machinery.ModuleSpec | None = None,
|
||||
doc: str | None = None,
|
||||
extra_objects: dict[str, t.Any] | None = None,
|
||||
):
|
||||
"""Lazily load this module as an object.
|
||||
|
||||
It does instantiate a __all__ and __dir__ for IDE support
|
||||
|
||||
Args:
|
||||
name: module name
|
||||
module_file: the given file. Often default to 'globals()['__file__']'
|
||||
import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module'
|
||||
module_spec: __spec__ of the lazily loaded module
|
||||
doc: Optional docstring for this module.
|
||||
extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well
|
||||
as any locals() functions
|
||||
"""
|
||||
super().__init__(name)
|
||||
self._modules = set(import_structure.keys())
|
||||
self._class_to_module: dict[str, str] = {}
|
||||
@@ -70,24 +82,22 @@ class LazyModule(types.ModuleType):
|
||||
self.__file__ = module_file
|
||||
self.__spec__ = module_spec
|
||||
self.__path__ = [os.path.dirname(module_file)]
|
||||
self.__doc__ = doc
|
||||
self._objects = _extra_objects
|
||||
self._name = name
|
||||
self._import_structure = import_structure
|
||||
|
||||
# Needed for autocompletion in an IDE
|
||||
def __dir__(self):
|
||||
"""Needed for autocompletion in an IDE."""
|
||||
result = t.cast("list[str]", super().__dir__())
|
||||
# The elements of self.__all__ that are submodules may or
|
||||
# may not be in the dir already, depending on whether
|
||||
# they have been accessed or not. So we only add the
|
||||
# elements of self.__all__ that are not already in the dir.
|
||||
for attribute in self.__all__:
|
||||
if attribute not in result:
|
||||
result.append(attribute)
|
||||
return result
|
||||
return result + [i for i in self.__all__ if i not in result]
|
||||
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
# currently, this is reserved to only internal uses and users shouldn't use this.
|
||||
"""This is reserved to only internal uses and users shouldn't use this."""
|
||||
if self._objects.get("__openllm_special__") is None:
|
||||
raise UsageNotAllowedError(f"'{self._name}' is not allowed to be used as a dict.")
|
||||
_special_mapping = self._objects.get("__openllm_special__", {})
|
||||
@@ -101,6 +111,10 @@ class LazyModule(types.ModuleType):
|
||||
raise KeyError(f"Failed to lookup '{key}' in '{self._name}'") from e
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
"""Equivocal __getattr__ implementation.
|
||||
|
||||
It checks from _objects > _modules and does it recursively.
|
||||
"""
|
||||
if name in _reserved_namespace:
|
||||
raise ForbiddenAttributeError(
|
||||
f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified."
|
||||
@@ -111,6 +125,7 @@ class LazyModule(types.ModuleType):
|
||||
warnings.warn(
|
||||
f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return getattr(self, cur_value)
|
||||
if name in self._objects:
|
||||
@@ -136,4 +151,5 @@ class LazyModule(types.ModuleType):
|
||||
) from e
|
||||
|
||||
def __reduce__(self):
|
||||
"""This is to ensure any given module is pickle-able."""
|
||||
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
|
||||
@@ -27,6 +26,7 @@ if t.TYPE_CHECKING:
|
||||
|
||||
class ReprMixin:
|
||||
"""This class display possible representation of given class.
|
||||
|
||||
It can be used for implementing __rich_pretty__ and __pretty__ methods in the future.
|
||||
Most subclass needs to implement a __repr_keys__ property.
|
||||
|
||||
@@ -41,12 +41,20 @@ class ReprMixin:
|
||||
"""This can be overriden by base class using this mixin."""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""The `__repr__` for any subclass of Mixin.
|
||||
|
||||
It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
|
||||
"""
|
||||
from . import bentoml_cattr
|
||||
|
||||
serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}
|
||||
return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""The string representation of the given Mixin subclass.
|
||||
|
||||
It will contains all of the attributes from __repr_keys__
|
||||
"""
|
||||
return self.__repr_str__(" ")
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
@@ -54,7 +62,12 @@ class ReprMixin:
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
return join_str.join(repr(v) if a is None else f"{a}={repr(v)}" for a, v in self.__repr_args__())
|
||||
"""To be used with __str__."""
|
||||
return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
|
||||
|
||||
def __repr_args__(self) -> ReprArgs:
|
||||
"""This can also be overriden by base class using this mixin.
|
||||
|
||||
By default it does a getattr of the current object from __repr_keys__.
|
||||
"""
|
||||
return ((k, getattr(self, k)) for k in self.__repr_keys__)
|
||||
|
||||
@@ -12,8 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
The actual client implementation. Use ``openllm.client`` instead.
|
||||
"""The actual client implementation.
|
||||
|
||||
Use ``openllm.client`` instead.
|
||||
This holds the implementation of the client, which is used to communicate with the
|
||||
OpenLLM server. It is used to send requests to the server, and receive responses.
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
@@ -42,7 +41,7 @@ class PromptTemplate:
|
||||
input_variables: t.Sequence[str]
|
||||
|
||||
def to_str(self, __partial_dict__: PartialDict | None = None, **attrs: str) -> str:
|
||||
"""Generate a prompt from the template and input variables"""
|
||||
"""Generate a prompt from the template and input variables."""
|
||||
if __partial_dict__:
|
||||
return _default_formatter.vformat(self.template, (), __partial_dict__)
|
||||
if not attrs:
|
||||
@@ -58,7 +57,7 @@ class PromptTemplate:
|
||||
|
||||
@classmethod
|
||||
def from_default(cls, model: str, /, **prompt_attrs: t.Any) -> PromptTemplate:
|
||||
template = getattr(openllm.utils.EnvVarMixin(model).module, "DEFAULT_PROMPT_TEMPLATE")
|
||||
template = openllm.utils.EnvVarMixin(model).module.DEFAULT_PROMPT_TEMPLATE
|
||||
if template is None:
|
||||
raise ValueError(f"Model {model} does not have a default prompt template.")
|
||||
if callable(template):
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Client that supports REST/gRPC protocol to interact with a LLMServer."""
|
||||
|
||||
from .grpc import AsyncGrpcClient as AsyncGrpcClient
|
||||
from .grpc import GrpcClient as GrpcClient
|
||||
from .http import AsyncHTTPClient as AsyncHTTPClient
|
||||
from .http import HTTPClient as HTTPClient
|
||||
|
||||
@@ -13,29 +13,30 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import typing as t
|
||||
from abc import abstractmethod
|
||||
from http import HTTPStatus
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
import logging
|
||||
|
||||
|
||||
# NOTE: We need to do this so that overload can register
|
||||
# correct overloads to typing registry
|
||||
if hasattr(t, "get_overloads"):
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import overload
|
||||
else:
|
||||
from typing_extensions import overload
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
from openllm.models.auto.factory import _BaseAutoLLMClass
|
||||
from openllm._types import LiteralRuntime
|
||||
|
||||
class AnnotatedClient(bentoml.client.Client):
|
||||
def health(self, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
@@ -44,12 +45,6 @@ if t.TYPE_CHECKING:
|
||||
async def async_health(self) -> t.Any:
|
||||
...
|
||||
|
||||
def call(self, name: str, inputs: t.Any, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
async def acall(self, name: str, inputs: t.Any, **attrs: t.Any) -> t.Any:
|
||||
...
|
||||
|
||||
def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
@@ -70,7 +65,10 @@ def in_async_context() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ClientMixin:
|
||||
T = t.TypeVar("T")
|
||||
|
||||
|
||||
class ClientMeta(t.Generic[T]):
|
||||
_api_version: str
|
||||
_client_class: type[bentoml.client.Client]
|
||||
|
||||
@@ -84,9 +82,9 @@ class ClientMixin:
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._address = address
|
||||
self._timeout = timeout
|
||||
assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation."
|
||||
|
||||
def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"):
|
||||
"""Initialise subclass for HTTP and gRPC client type."""
|
||||
cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient
|
||||
cls._api_version = api_version
|
||||
|
||||
@@ -102,7 +100,7 @@ class ClientMixin:
|
||||
return self.__agent__
|
||||
|
||||
@property
|
||||
def _metadata(self) -> dict[str, t.Any]:
|
||||
def _metadata(self) -> T:
|
||||
if in_async_context():
|
||||
return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json()
|
||||
return self.call("metadata")
|
||||
@@ -114,7 +112,7 @@ class ClientMixin:
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def framework(self) -> t.Literal["pt", "flax", "tf"]:
|
||||
def framework(self) -> LiteralRuntime:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@@ -135,10 +133,7 @@ class ClientMixin:
|
||||
@property
|
||||
def llm(self) -> openllm.LLM[t.Any, t.Any]:
|
||||
if self.__llm__ is None:
|
||||
self.__llm__ = t.cast(
|
||||
"_BaseAutoLLMClass",
|
||||
openllm[self.framework], # type: ignore (internal API)
|
||||
).for_model(self.model_name)
|
||||
self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name)
|
||||
return self.__llm__
|
||||
|
||||
@property
|
||||
@@ -171,7 +166,7 @@ class ClientMixin:
|
||||
...
|
||||
|
||||
|
||||
class BaseClient(ClientMixin):
|
||||
class BaseClient(ClientMeta[T]):
|
||||
def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -183,22 +178,32 @@ class BaseClient(ClientMixin):
|
||||
def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
|
||||
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs)
|
||||
@overload
|
||||
def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
def query(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | dict[str, t.Any] | str:
|
||||
# NOTE: We set use_default_prompt_template to False for now.
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
return_attrs = attrs.pop("return_attrs", False)
|
||||
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(
|
||||
prompt, use_default_prompt_template=use_default_prompt_template, **attrs
|
||||
)
|
||||
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
|
||||
if in_async_context():
|
||||
result = httpx.post(
|
||||
urljoin(self._address, f"/{self._api_version}/generate"),
|
||||
json=openllm.utils.bentoml_cattr.unstructure(inputs),
|
||||
json=inputs.model_dump(),
|
||||
timeout=self.timeout,
|
||||
).json()
|
||||
else:
|
||||
result = self.call("generate", inputs)
|
||||
result = self.call("generate", inputs.model_dump())
|
||||
r = self.postprocess(result)
|
||||
|
||||
if return_attrs:
|
||||
return r
|
||||
if return_raw_response:
|
||||
return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
|
||||
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
def ask_agent(
|
||||
@@ -235,10 +240,21 @@ class BaseClient(ClientMixin):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseAsyncClient(ClientMixin):
|
||||
class BaseAsyncClient(ClientMeta[T]):
|
||||
async def health(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
async def query(
|
||||
self,
|
||||
prompt: str,
|
||||
*,
|
||||
return_attrs: t.Literal[True] = True,
|
||||
return_raw_response: bool | None = ...,
|
||||
**attrs: t.Any,
|
||||
) -> openllm.GenerationOutput:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str:
|
||||
...
|
||||
@@ -249,19 +265,21 @@ class BaseAsyncClient(ClientMixin):
|
||||
) -> dict[str, t.Any]:
|
||||
...
|
||||
|
||||
async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str:
|
||||
async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str | openllm.GenerationOutput:
|
||||
# NOTE: We set use_default_prompt_template to False for now.
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
return_attrs = attrs.pop("return_attrs", False)
|
||||
return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(
|
||||
prompt, use_default_prompt_template=use_default_prompt_template, **attrs
|
||||
)
|
||||
inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs))
|
||||
res = await self.acall("generate", inputs)
|
||||
res = await self.acall("generate", inputs.model_dump())
|
||||
r = self.postprocess(res)
|
||||
|
||||
if return_attrs:
|
||||
return r
|
||||
if return_raw_response:
|
||||
return openllm.utils.bentoml_cattr.unstructure(r)
|
||||
|
||||
return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
async def ask_agent(
|
||||
@@ -273,13 +291,17 @@ class BaseAsyncClient(ClientMixin):
|
||||
agent_type: t.LiteralString = "hf",
|
||||
**attrs: t.Any,
|
||||
) -> t.Any:
|
||||
"""Async version of agent.run"""
|
||||
"""Async version of agent.run."""
|
||||
if agent_type == "hf":
|
||||
return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if not openllm.utils.is_transformers_supports_agent():
|
||||
raise RuntimeError(
|
||||
"This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0"
|
||||
)
|
||||
if len(args) > 1:
|
||||
raise ValueError("'args' should only take one positional argument.")
|
||||
task = kwargs.pop("task", args[0])
|
||||
@@ -293,7 +315,7 @@ class BaseAsyncClient(ClientMixin):
|
||||
|
||||
_hf_agent = self._hf_agent
|
||||
|
||||
prompt = _hf_agent.format_prompt(task)
|
||||
prompt = t.cast(str, _hf_agent.format_prompt(task))
|
||||
stop = ["Task:"]
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client:
|
||||
response = await client.post(
|
||||
@@ -303,7 +325,7 @@ class BaseAsyncClient(ClientMixin):
|
||||
"parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
if response.status_code != HTTPStatus.OK:
|
||||
raise ValueError(f"Error {response.status_code}: {response.json()}")
|
||||
|
||||
result = response.json()[0]["generated_text"]
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
@@ -27,46 +26,51 @@ from .base import BaseClient
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import grpc_health.v1.health_pb2 as health_pb2
|
||||
from grpc_health.v1 import health_pb2
|
||||
|
||||
from bentoml.grpc.v1.service_pb2 import Response
|
||||
from openllm._types import LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GrpcClientMixin:
|
||||
_metadata: Response
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def _metadata(self) -> Response:
|
||||
...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata.json.struct_value.fields["model_name"].string_value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def framework(self) -> t.Literal["pt", "flax", "tf"]:
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
value = self._metadata.json.struct_value.fields["framework"].string_value
|
||||
if value not in ("pt", "flax", "tf"):
|
||||
raise KeyError
|
||||
return value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try:
|
||||
return int(self._metadata.json.struct_value.fields["timeout"].number_value)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try:
|
||||
return self._metadata.json.struct_value.fields["model_id"].string_value
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
@@ -74,7 +78,7 @@ class GrpcClientMixin:
|
||||
v = self._metadata.json.struct_value.fields["configuration"].string_value
|
||||
return orjson.loads(v)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
if isinstance(result, dict):
|
||||
@@ -85,7 +89,7 @@ class GrpcClientMixin:
|
||||
return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True))
|
||||
|
||||
|
||||
class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
|
||||
class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
@@ -94,7 +98,7 @@ class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"):
|
||||
return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService"))
|
||||
|
||||
|
||||
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"):
|
||||
class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(":")
|
||||
super().__init__(address, timeout)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
from urllib.parse import urlparse
|
||||
@@ -26,52 +25,63 @@ from .base import BaseAsyncClient
|
||||
from .base import BaseClient
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._types import DictStrAny
|
||||
from openllm._types import LiteralRuntime
|
||||
else:
|
||||
DictStrAny = dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTTPClientMixin:
|
||||
_metadata: dict[str, t.Any]
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def _metadata(self) -> DictStrAny:
|
||||
...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def framework(self) -> t.Literal["pt", "flax", "tf"]:
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
return self._metadata["framework"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try:
|
||||
return self._metadata["timeout"]
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try:
|
||||
return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return openllm.GenerationOutput(**result)
|
||||
|
||||
|
||||
class HTTPClient(HTTPClientMixin, BaseClient):
|
||||
class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
@@ -81,7 +91,7 @@ class HTTPClient(HTTPClientMixin, BaseClient):
|
||||
return self._cached.health()
|
||||
|
||||
|
||||
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient):
|
||||
class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]):
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
address = address if "://" in address else "http://" + address
|
||||
self._host, self._port = urlparse(address).netloc.split(":")
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
@@ -62,11 +61,11 @@ def make_llm_config(
|
||||
lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
|
||||
if fields is not None:
|
||||
for field, type_, default in fields:
|
||||
lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({repr(default)})")
|
||||
lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})")
|
||||
if generation_fields is not None:
|
||||
generation_lines = ["class GenerationConfig:"]
|
||||
for field, default in generation_fields:
|
||||
generation_lines.append(f" {field} = {repr(default)}")
|
||||
generation_lines.append(f" {field} = {default!r}")
|
||||
lines.extend((" " + line for line in generation_lines))
|
||||
|
||||
script = "\n".join(lines)
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""All configuration-related tests for openllm.LLMConfig. This will include testing
|
||||
for ModelEnv construction and parsing environment variables."""
|
||||
for ModelEnv construction and parsing environment variables.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
@@ -125,29 +125,20 @@ def test_complex_struct_dump(
|
||||
generation_fields=(("temperature", temperature),),
|
||||
)
|
||||
sent = cl_()
|
||||
assert (
|
||||
sent.model_dump()["field1"] == field1 and sent.model_dump()["generation_config"]["temperature"] == temperature
|
||||
)
|
||||
assert (
|
||||
sent.model_dump(flatten=True)["field1"] == field1
|
||||
and sent.model_dump(flatten=True)["temperature"] == temperature
|
||||
)
|
||||
assert sent.model_dump()["field1"] == field1
|
||||
assert sent.model_dump()["generation_config"]["temperature"] == temperature
|
||||
assert sent.model_dump(flatten=True)["field1"] == field1
|
||||
assert sent.model_dump(flatten=True)["temperature"] == temperature
|
||||
|
||||
passed = cl_(field1=input_field1, temperature=input_temperature)
|
||||
assert (
|
||||
passed.model_dump()["field1"] == input_field1
|
||||
and passed.model_dump()["generation_config"]["temperature"] == input_temperature
|
||||
)
|
||||
assert (
|
||||
passed.model_dump(flatten=True)["field1"] == input_field1
|
||||
and passed.model_dump(flatten=True)["temperature"] == input_temperature
|
||||
)
|
||||
assert passed.model_dump()["field1"] == input_field1
|
||||
assert passed.model_dump()["generation_config"]["temperature"] == input_temperature
|
||||
assert passed.model_dump(flatten=True)["field1"] == input_field1
|
||||
assert passed.model_dump(flatten=True)["temperature"] == input_temperature
|
||||
|
||||
pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1)
|
||||
assert (
|
||||
pas_nested.model_dump()["field1"] == input_field1
|
||||
and pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature
|
||||
)
|
||||
assert pas_nested.model_dump()["field1"] == input_field1
|
||||
assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -207,7 +198,7 @@ def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
@pytest.mark.parametrize("return_dict,typ", [(True, DictStrAny), (False, transformers.GenerationConfig)])
|
||||
@pytest.mark.parametrize(("return_dict", "typ"), [(True, DictStrAny), (False, transformers.GenerationConfig)])
|
||||
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config("ConversionLLM", gen_settings)
|
||||
assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)
|
||||
@@ -13,9 +13,55 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
import itertools
|
||||
import typing as t
|
||||
|
||||
import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm._types import LiteralRuntime
|
||||
|
||||
|
||||
_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m"}
|
||||
_PROMPT_MAPPING = {
|
||||
"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",
|
||||
}
|
||||
|
||||
|
||||
def parametrise_local_llm(
|
||||
model: str,
|
||||
) -> t.Generator[tuple[str, openllm.LLMRunner | openllm.LLM[t.Any, t.Any]], None, None]:
|
||||
if model not in _FRAMEWORK_MAPPING:
|
||||
pytest.skip(f"'{model}' is not yet supported in framework testing.")
|
||||
|
||||
runtime_impl: tuple[LiteralRuntime, ...] = tuple()
|
||||
if model in openllm.MODEL_MAPPING_NAMES:
|
||||
runtime_impl += ("pt",)
|
||||
if model in openllm.MODEL_FLAX_MAPPING_NAMES:
|
||||
runtime_impl += ("flax",)
|
||||
if model in openllm.MODEL_TF_MAPPING_NAMES:
|
||||
runtime_impl += ("tf",)
|
||||
|
||||
for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()):
|
||||
llm = openllm.Runner(
|
||||
model,
|
||||
model_id=_FRAMEWORK_MAPPING[model],
|
||||
ensure_available=True,
|
||||
implementation=framework,
|
||||
init_local=True,
|
||||
)
|
||||
yield prompt, llm
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
|
||||
if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames:
|
||||
metafunc.parametrize(
|
||||
"prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]
|
||||
)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
|
||||
# If no tests are collected, pytest exists with code 5, which makes the CI fail.
|
||||
|
||||
@@ -13,12 +13,16 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
import pytest
|
||||
from openllm._llm import make_tag
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import pytest
|
||||
|
||||
|
||||
HF_INTERNAL_T5_TESTING = "hf-internal-testing/tiny-random-t5"
|
||||
|
||||
|
||||
@@ -31,21 +35,6 @@ def patch_hash_from_file(_: str, algorithm: t.LiteralString = "sha1") -> str:
|
||||
return "d88a1a40e354a0c7fa6f9055938594e6a4c712e0"
|
||||
|
||||
|
||||
def test_tag_generation_from_custom_path(
|
||||
tmp_path_factory: pytest.TempPathFactory, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
|
||||
):
|
||||
monkeypatch.setattr(openllm._llm, "generate_hash_from_file", patch_hash_from_file)
|
||||
local_path = tmp_path_factory.mktemp("local_t5")
|
||||
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
|
||||
llm.save_pretrained(local_path)
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
tag = make_tag(local_path.resolve().__fspath__())
|
||||
|
||||
assert tag.version == "d88a1a40e354a0c7fa6f9055938594e6a4c712e0"
|
||||
assert "Given 'model_id" in caplog.text
|
||||
|
||||
|
||||
def test_tag_generation_quiet_log(tmp_path_factory: pytest.TempPathFactory, caplog: pytest.LogCaptureFixture):
|
||||
local_path = tmp_path_factory.mktemp("local_t5")
|
||||
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
|
||||
@@ -54,12 +43,3 @@ def test_tag_generation_quiet_log(tmp_path_factory: pytest.TempPathFactory, capl
|
||||
with caplog.at_level("WARNING"):
|
||||
make_tag(local_path.resolve().__fspath__(), quiet=True)
|
||||
assert not caplog.text
|
||||
|
||||
|
||||
def test_tag_generation_debug_log(caplog: pytest.LogCaptureFixture):
|
||||
with caplog.at_level("DEBUG"):
|
||||
make_tag(HF_INTERNAL_T5_TESTING)
|
||||
assert (
|
||||
"The full tag to be saved under model store: 'pt-hf-internal-testing-tiny-random-t5:2f582cd79ed5795b71539951d237945bc1c5ac7e'"
|
||||
in caplog.text
|
||||
)
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"configuration": {
|
||||
"format_outputs": false,
|
||||
"generation_config": {
|
||||
"diversity_penalty": 0.0,
|
||||
"early_stopping": false,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"encoder_repetition_penalty": 1.0,
|
||||
"epsilon_cutoff": 0.0,
|
||||
"eta_cutoff": 0.0,
|
||||
"length_penalty": 1.0,
|
||||
"max_new_tokens": 20,
|
||||
"min_length": 0,
|
||||
"no_repeat_ngram_size": 0,
|
||||
"num_beam_groups": 1,
|
||||
"num_beams": 1,
|
||||
"num_return_sequences": 1,
|
||||
"output_attentions": false,
|
||||
"output_hidden_states": false,
|
||||
"output_scores": false,
|
||||
"remove_invalid_values": false,
|
||||
"renormalize_logits": false,
|
||||
"repetition_penalty": 1.0,
|
||||
"temperature": 0.75,
|
||||
"top_k": 15,
|
||||
"top_p": 1.0,
|
||||
"typical_p": 1.0,
|
||||
"use_cache": true
|
||||
}
|
||||
},
|
||||
"responses": [
|
||||
"What is Deep learning?\nDeep learning is a new way of studying the content and making an informed decision. It is the"
|
||||
]
|
||||
}
|
||||
@@ -13,45 +13,303 @@
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
import asyncio
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import typing as t
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import attr
|
||||
import docker
|
||||
import docker.errors
|
||||
import docker.types
|
||||
import orjson
|
||||
import pytest
|
||||
from syrupy.extensions.json import JSONSnapshotExtension
|
||||
|
||||
import openllm
|
||||
from openllm._llm import normalise_model_name
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm.models.auto.factory import _BaseAutoLLMClass
|
||||
import subprocess
|
||||
|
||||
_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m"}
|
||||
_PROMPT_MAPPING = {
|
||||
"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",
|
||||
"default": "What is the weather in SF?",
|
||||
}
|
||||
from openllm_client.runtimes.base import BaseAsyncClient
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from syrupy.types import PropertyFilter
|
||||
from syrupy.types import PropertyMatcher
|
||||
from syrupy.types import SerializableData
|
||||
from syrupy.types import SerializedData
|
||||
|
||||
from openllm._configuration import GenerationConfig
|
||||
from openllm._types import DictStrAny
|
||||
from openllm._types import ListAny
|
||||
|
||||
else:
|
||||
DictStrAny = dict
|
||||
ListAny = list
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
|
||||
models, fname = t.cast(types.ModuleType, metafunc.module).__name__.partition(".")[-1].split(".")[1:]
|
||||
class ResponseComparator(JSONSnapshotExtension):
|
||||
def serialize(
|
||||
self,
|
||||
data: SerializableData,
|
||||
*,
|
||||
exclude: PropertyFilter | None = None,
|
||||
matcher: PropertyMatcher | None = None,
|
||||
) -> SerializedData:
|
||||
if openllm.utils.LazyType(ListAny).isinstance(data):
|
||||
data = [d.unmarshaled for d in data]
|
||||
else:
|
||||
data = data.unmarshaled
|
||||
data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher)
|
||||
return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
|
||||
|
||||
if "tf" in fname:
|
||||
framework = "tf"
|
||||
elif "flax" in fname:
|
||||
framework = "flax"
|
||||
def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
|
||||
def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
|
||||
try:
|
||||
data = orjson.loads(data)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise ValueError(f"Failed to decode JSON data: {data}") from err
|
||||
if openllm.utils.LazyType(DictStrAny).isinstance(data):
|
||||
return openllm.GenerationOutput(**data)
|
||||
elif openllm.utils.LazyType(ListAny).isinstance(data):
|
||||
return [openllm.GenerationOutput(**d) for d in data]
|
||||
else:
|
||||
raise NotImplementedError(f"Data {data} has unsupported type.")
|
||||
|
||||
serialized_data = convert_data(serialized_data)
|
||||
snapshot_data = convert_data(snapshot_data)
|
||||
|
||||
if openllm.utils.LazyType(ListAny).isinstance(serialized_data):
|
||||
serialized_data = [serialized_data]
|
||||
if openllm.utils.LazyType(ListAny).isinstance(snapshot_data):
|
||||
snapshot_data = [snapshot_data]
|
||||
|
||||
def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool:
|
||||
return s == t
|
||||
|
||||
def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
|
||||
return (
|
||||
len(s.responses) == len(t.responses)
|
||||
and all([_s == _t for _s, _t in zip(s.responses, t.responses)])
|
||||
and eq_config(s.marshaled_config, t.marshaled_config)
|
||||
)
|
||||
|
||||
return len(serialized_data) == len(snapshot_data) and all(
|
||||
[eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def response_snapshot(snapshot: SnapshotAssertion):
|
||||
return snapshot.use_extension(ResponseComparator)
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class _Handle(ABC):
|
||||
port: int
|
||||
deployment_mode: t.Literal["container", "local"]
|
||||
|
||||
client: BaseAsyncClient[t.Any] = attr.field(init=False)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
|
||||
...
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}")
|
||||
|
||||
@abstractmethod
|
||||
def status(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def health(self, timeout: int = 240):
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if not self.status():
|
||||
raise RuntimeError(f"Failed to initialise {self.__class__.__name__}")
|
||||
await self.client.health()
|
||||
try:
|
||||
await self.client.query("sanity")
|
||||
return
|
||||
except Exception:
|
||||
time.sleep(1)
|
||||
raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class LocalHandle(_Handle):
|
||||
process: subprocess.Popen[bytes]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process: subprocess.Popen[bytes],
|
||||
port: int,
|
||||
deployment_mode: t.Literal["container", "local"],
|
||||
):
|
||||
self.__attrs_init__(port, deployment_mode, process)
|
||||
|
||||
def status(self) -> bool:
|
||||
return self.process.poll() is None
|
||||
|
||||
|
||||
class HandleProtocol(t.Protocol):
|
||||
@contextlib.contextmanager
|
||||
def __call__(
|
||||
*,
|
||||
model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
quantize: t.AnyStr | None = None,
|
||||
) -> t.Generator[_Handle, None, None]:
|
||||
...
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class DockerHandle(_Handle):
|
||||
container_name: str
|
||||
docker_client: docker.DockerClient
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
docker_client: docker.DockerClient,
|
||||
container_name: str,
|
||||
port: int,
|
||||
deployment_mode: t.Literal["container", "local"],
|
||||
):
|
||||
self.__attrs_init__(port, deployment_mode, container_name, docker_client)
|
||||
|
||||
def status(self) -> bool:
|
||||
container = self.docker_client.containers.get(self.container_name)
|
||||
return container.status in ["running", "created"]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _local_handle(
|
||||
model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
deployment_mode: t.Literal["container", "local"],
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
*,
|
||||
_serve_grpc: bool = False,
|
||||
):
|
||||
with openllm.utils.reserve_free_port() as port:
|
||||
pass
|
||||
|
||||
if not _serve_grpc:
|
||||
proc = openllm.start(
|
||||
model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True
|
||||
)
|
||||
else:
|
||||
framework = "pt"
|
||||
proc = openllm.start_grpc(
|
||||
model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True
|
||||
)
|
||||
|
||||
llm, runner_kwargs = t.cast(
|
||||
"_BaseAutoLLMClass",
|
||||
openllm[framework], # type: ignore
|
||||
).for_model(models, model_id=_FRAMEWORK_MAPPING[models], return_runner_kwargs=True, ensure_available=True)
|
||||
llm.ensure_model_id_exists()
|
||||
if "runner" in metafunc.function.__name__:
|
||||
llm = llm.to_runner(**runner_kwargs)
|
||||
llm.init_local(quiet=True)
|
||||
yield LocalHandle(proc, port, deployment_mode)
|
||||
proc.terminate()
|
||||
proc.wait(60)
|
||||
|
||||
if "qa" in metafunc.fixturenames:
|
||||
metafunc.parametrize("prompt,llm,qa", [(_PROMPT_MAPPING["qa"], llm, True)])
|
||||
process_output = proc.stdout.read()
|
||||
print(process_output, file=sys.stderr)
|
||||
|
||||
proc.stdout.close()
|
||||
if proc.stderr:
|
||||
proc.stderr.close()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _container_handle(
|
||||
model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
deployment_mode: t.Literal["container", "local"],
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
*,
|
||||
_serve_grpc: bool = False,
|
||||
):
|
||||
envvar = openllm.utils.EnvVarMixin(model)
|
||||
|
||||
with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
|
||||
pass
|
||||
container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_")
|
||||
client = docker.from_env()
|
||||
try:
|
||||
container = client.containers.get(container_name)
|
||||
container.stop()
|
||||
container.wait()
|
||||
container.remove()
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
|
||||
args = ["serve" if not _serve_grpc else "serve-grpc"]
|
||||
|
||||
env: DictStrAny = {}
|
||||
|
||||
if quantize is not None:
|
||||
env[envvar.quantize] = quantize
|
||||
|
||||
available = openllm.utils.gpu_count()
|
||||
gpus = len(available) if len(available) > 0 else -1
|
||||
devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None
|
||||
|
||||
container = client.containers.run(
|
||||
image_tag,
|
||||
command=args,
|
||||
name=container_name,
|
||||
environment=env,
|
||||
auto_remove=False,
|
||||
detach=True,
|
||||
device_requests=devs,
|
||||
ports={"3000/tcp": port, "3001/tcp": prom_port},
|
||||
)
|
||||
|
||||
yield DockerHandle(client, container.name, port, deployment_mode)
|
||||
|
||||
try:
|
||||
container.stop()
|
||||
container.wait()
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
|
||||
container_output = container.logs().decode("utf-8")
|
||||
print(container_output, file=sys.stderr)
|
||||
|
||||
container.remove()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
|
||||
stack = contextlib.ExitStack()
|
||||
yield stack
|
||||
stack.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(params=["container", "local"], scope="session")
|
||||
def deployment_mode(request: pytest.FixtureRequest) -> str:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
|
||||
if deployment_mode == "container":
|
||||
return functools.partial(_container_handle, deployment_mode=deployment_mode)
|
||||
elif deployment_mode == "local":
|
||||
return functools.partial(_local_handle, deployment_mode=deployment_mode)
|
||||
else:
|
||||
metafunc.parametrize("prompt,llm", [(_PROMPT_MAPPING["default"], llm)])
|
||||
raise ValueError(f"Unknown deployment mode: {deployment_mode}")
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,27 +0,0 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
def test_small_flan(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool):
|
||||
assert llm(prompt)
|
||||
|
||||
|
||||
def test_small_runner_flan(prompt: str, llm: openllm.LLMRunner, qa: bool):
|
||||
assert llm(prompt)
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2023 BentoML Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
@openllm.tests.require_tf
|
||||
def test_small_tf_flan(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool):
|
||||
assert llm(prompt)
|
||||
|
||||
|
||||
@openllm.tests.require_tf
|
||||
def test_small_tf_runner_flan(prompt: str, llm: openllm.LLMRunner, qa: bool):
|
||||
assert llm(prompt)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user