From c7f4dc7bb21c323dfaf9c446fc74bec6db6bd8b1 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Mon, 10 Jul 2023 17:23:19 -0400 Subject: [PATCH] feat(test): snapshot testing (#107) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .gitattributes | 2 + .github/actions/setup-repo/action.yml | 11 - .github/workflows/ci.yml | 17 +- .github/workflows/cleanup-cache.yml | 46 - .pre-commit-config.yaml | 22 +- DEVELOPMENT.md | 15 +- README.md | 3 + changelog.d/107.fix.md | 26 + examples/langchain-chains-demo/service.py | 11 +- hatch.toml | 27 +- pyproject.toml | 131 ++- src/openllm/__init__.py | 27 +- src/openllm/__main__.py | 3 +- src/openllm/_configuration.py | 527 ++++----- src/openllm/_generation.py | 11 +- src/openllm/_llm.py | 1022 +++++++++-------- src/openllm/_package.py | 99 +- src/openllm/_prompt.py | 5 +- src/openllm/_quantisation.py | 89 ++ src/openllm/_schema.py | 15 +- src/openllm/_service.py | 14 +- src/openllm/_strategies.py | 86 +- src/openllm/_types.py | 38 +- src/openllm/cli.py | 885 ++++++++++---- src/openllm/client.py | 4 +- src/openllm/exceptions.py | 4 +- src/openllm/models/auto/__init__.py | 6 +- src/openllm/models/auto/configuration_auto.py | 12 +- src/openllm/models/auto/factory.py | 29 +- src/openllm/models/auto/modeling_auto.py | 5 +- src/openllm/models/auto/modeling_flax_auto.py | 5 +- src/openllm/models/auto/modeling_tf_auto.py | 5 +- src/openllm/models/chatglm/__init__.py | 5 +- .../models/chatglm/configuration_chatglm.py | 4 +- .../models/chatglm/modeling_chatglm.py | 14 +- src/openllm/models/dolly_v2/__init__.py | 5 +- .../models/dolly_v2/configuration_dolly_v2.py | 18 +- .../models/dolly_v2/modeling_dolly_v2.py | 26 +- src/openllm/models/falcon/__init__.py | 4 +- .../models/falcon/configuration_falcon.py | 6 +- src/openllm/models/falcon/modeling_falcon.py | 8 +- src/openllm/models/flan_t5/__init__.py | 6 +- .../models/flan_t5/configuration_flan_t5.py | 5 +- .../models/flan_t5/modeling_flan_t5.py | 7 +- .../models/flan_t5/modeling_flax_flan_t5.py | 7 +- .../models/flan_t5/modeling_tf_flan_t5.py | 20 +- src/openllm/models/gpt_neox/__init__.py | 4 +- .../models/gpt_neox/configuration_gpt_neox.py | 7 +- .../models/gpt_neox/modeling_gpt_neox.py | 10 +- src/openllm/models/mpt/__init__.py | 4 +- src/openllm/models/mpt/configuration_mpt.py | 8 +- src/openllm/models/mpt/modeling_mpt.py | 20 +- src/openllm/models/opt/__init__.py | 6 +- src/openllm/models/opt/configuration_opt.py | 4 +- src/openllm/models/opt/modeling_flax_opt.py | 18 +- src/openllm/models/opt/modeling_opt.py | 20 +- src/openllm/models/opt/modeling_tf_opt.py | 18 +- src/openllm/models/stablelm/__init__.py | 4 +- .../models/stablelm/configuration_stablelm.py | 7 +- .../models/stablelm/modeling_stablelm.py | 3 +- src/openllm/models/starcoder/__init__.py | 4 +- .../starcoder/configuration_starcoder.py | 3 +- .../models/starcoder/modeling_starcoder.py | 15 +- src/openllm/playground/falcon_tuned.py | 1 - src/openllm/playground/features.py | 1 - src/openllm/playground/opt_tuned.py | 1 - src/openllm/serialisation/__init__.py | 22 +- src/openllm/serialisation/ggml.py | 30 +- src/openllm/serialisation/transformers.py | 82 +- src/openllm/testing.py | 113 ++ src/openllm/tests.py | 39 - src/openllm/utils/__init__.py | 258 ++++- src/openllm/utils/analytics.py | 4 +- src/openllm/utils/codegen.py | 139 ++- src/openllm/utils/dantic.py | 138 ++- src/openllm/utils/dummy_flax_objects.py | 1 - .../utils/dummy_pt_and_cpm_kernels_objects.py | 1 - .../utils/dummy_pt_and_einops_objects.py | 1 - .../utils/dummy_pt_and_triton_objects.py | 1 - src/openllm/utils/dummy_pt_objects.py | 1 - src/openllm/utils/dummy_tf_objects.py | 1 - src/openllm/utils/import_utils.py | 41 +- src/openllm/utils/lazy.py | 38 +- src/openllm/utils/representation.py | 17 +- src/openllm_client/__init__.py | 5 +- src/openllm_client/_prompt.py | 5 +- src/openllm_client/runtimes/__init__.py | 3 + src/openllm_client/runtimes/base.py | 84 +- src/openllm_client/runtimes/grpc.py | 26 +- src/openllm_client/runtimes/http.py | 30 +- tests/_strategies/_configuration.py | 5 +- tests/{test_client.py => client_test.py} | 0 ...configuration.py => configuration_test.py} | 35 +- tests/conftest.py | 46 + tests/{test_llm.py => llm_test.py} | 32 +- .../opt_test/test_opt_125m[container].json | 34 + tests/models/conftest.py | 310 ++++- tests/models/flan_t5/__init__.py | 13 - tests/models/flan_t5/test_modeling_flan_t5.py | 27 - .../flan_t5/test_modeling_tf_flan_t5.py | 29 - tests/models/flan_t5_test.py | 60 + tests/models/opt/__init__.py | 13 - tests/models/opt/test_modeling_flax_opt.py | 29 - tests/models/opt/test_modeling_opt.py | 27 - tests/models/opt/test_modeling_tf_opt.py | 29 - tests/models/opt_test.py | 59 + ...odeling_flax_flan_t5.py => models_test.py} | 15 +- tests/{test_package.py => package_test.py} | 15 +- ...{test_strategies.py => strategies_test.py} | 3 +- tools/assert-model-table-latest | 5 +- typings/IPython/__init__.pyi | 90 ++ typings/IPython/core/__init__.pyi | 0 typings/IPython/core/getipython.pyi | 5 + typings/IPython/terminal/__init__.pyi | 3 + typings/IPython/terminal/debugger.pyi | 41 + typings/IPython/terminal/embed.pyi | 165 +++ typings/IPython/terminal/interactiveshell.pyi | 125 ++ typings/IPython/terminal/ipapp.pyi | 70 ++ typings/IPython/terminal/magics.pyi | 119 ++ typings/IPython/terminal/prompts.pyi | 27 + .../terminal/pt_inputhooks/__init__.pyi | 23 + typings/IPython/terminal/ptutils.pyi | 29 + .../IPython/terminal/shortcuts/__init__.pyi | 149 +++ .../IPython/terminal/shortcuts/auto_match.pyi | 65 ++ .../terminal/shortcuts/auto_suggest.pyi | 101 ++ .../IPython/terminal/shortcuts/filters.pyi | 81 ++ typings/attr/__init__.pyi | 144 ++- typings/attr/_typing_compat.pyi | 7 +- typings/click_option_group/_core.pyi | 114 +- typings/click_option_group/_decorators.pyi | 105 +- typings/click_option_group/_helpers.pyi | 25 - typings/click_option_group/_version.pyi | 4 +- typings/deepmerge/strategy/dict.pyi | 2 +- typings/deepmerge/strategy/list.pyi | 2 +- typings/deepmerge/strategy/set.pyi | 2 +- typings/docker/__init__.pyi | 2 + typings/docker/api/__init__.pyi | 1 + typings/docker/api/build.pyi | 150 +++ typings/docker/api/client.pyi | 96 ++ typings/docker/api/config.pyi | 61 + typings/docker/api/container.pyi | 962 ++++++++++++++++ typings/docker/api/daemon.pyi | 115 ++ typings/docker/api/exec_api.pyi | 100 ++ typings/docker/api/image.pyi | 336 ++++++ typings/docker/api/network.pyi | 174 +++ typings/docker/api/plugin.pyi | 160 +++ typings/docker/api/secret.pyi | 60 + typings/docker/api/service.pyi | 217 ++++ typings/docker/api/swarm.pyi | 318 +++++ typings/docker/api/volume.pyi | 112 ++ typings/docker/client.pyi | 83 ++ typings/docker/constants.pyi | 19 + typings/docker/context/__init__.pyi | 2 + typings/docker/context/api.pyi | 129 +++ typings/docker/context/config.pyi | 12 + typings/docker/context/context.pyi | 29 + typings/docker/credentials/__init__.pyi | 7 + typings/docker/credentials/constants.pyi | 4 + typings/docker/credentials/errors.pyi | 7 + typings/docker/credentials/store.pyi | 27 + typings/docker/credentials/utils.pyi | 5 + typings/docker/errors.pyi | 62 + typings/docker/models/__init__.pyi | 1 + typings/docker/models/configs.pyi | 56 + typings/docker/models/containers.pyi | 832 ++++++++++++++ typings/docker/models/images.pyi | 329 ++++++ typings/docker/models/networks.pyi | 175 +++ typings/docker/models/nodes.pyi | 95 ++ typings/docker/models/plugins.pyi | 152 +++ typings/docker/models/resource.pyi | 38 + typings/docker/models/secrets.pyi | 56 + typings/docker/models/services.pyi | 227 ++++ typings/docker/models/swarm.pyi | 143 +++ typings/docker/models/volumes.pyi | 85 ++ typings/docker/transport/__init__.pyi | 5 + typings/docker/transport/basehttpadapter.pyi | 6 + typings/docker/transport/npipeconn.pyi | 20 + typings/docker/transport/npipesocket.pyi | 66 ++ typings/docker/transport/sshconn.pyi | 32 + typings/docker/transport/ssladapter.pyi | 26 + typings/docker/transport/unixconn.pyi | 20 + typings/docker/types/__init__.pyi | 30 + typings/docker/types/base.pyi | 4 + typings/docker/types/containers.pyi | 238 ++++ typings/docker/types/daemon.pyi | 21 + typings/docker/types/healthcheck.pyi | 51 + typings/docker/types/networks.pyi | 66 ++ typings/docker/types/services.pyi | 427 +++++++ typings/docker/types/swarm.pyi | 48 + typings/docker/utils/__init__.pyi | 0 typings/docker/utils/build.pyi | 31 + typings/docker/utils/config.pyi | 15 + typings/docker/utils/decorators.pyi | 5 + typings/docker/utils/fnmatch.pyi | 47 + typings/docker/utils/json_stream.pyi | 34 + typings/docker/utils/proxy.pyi | 32 + typings/docker/utils/socket.pyi | 67 ++ typings/docker/utils/utils.pyi | 48 + typings/nbformat/__init__.pyi | 1 + typings/nbformat/v4/__init__.pyi | 35 + typings/nbformat/v4/convert.pyi | 95 ++ typings/nbformat/v4/nbbase.pyi | 52 + typings/nbformat/v4/nbjson.pyi | 45 + typings/nbformat/v4/rwbase.pyi | 56 + typings/rsmiBindings.pyi | 61 +- 205 files changed, 11633 insertions(+), 2349 deletions(-) delete mode 100644 .github/workflows/cleanup-cache.yml create mode 100644 changelog.d/107.fix.md create mode 100644 src/openllm/testing.py delete mode 100644 src/openllm/tests.py rename tests/{test_client.py => client_test.py} (100%) rename tests/{test_configuration.py => configuration_test.py} (88%) rename tests/{test_llm.py => llm_test.py} (60%) create mode 100644 tests/models/__snapshots__/opt_test/test_opt_125m[container].json delete mode 100644 tests/models/flan_t5/__init__.py delete mode 100644 tests/models/flan_t5/test_modeling_flan_t5.py delete mode 100644 tests/models/flan_t5/test_modeling_tf_flan_t5.py create mode 100644 tests/models/flan_t5_test.py delete mode 100644 tests/models/opt/__init__.py delete mode 100644 tests/models/opt/test_modeling_flax_opt.py delete mode 100644 tests/models/opt/test_modeling_opt.py delete mode 100644 tests/models/opt/test_modeling_tf_opt.py create mode 100644 tests/models/opt_test.py rename tests/{models/flan_t5/test_modeling_flax_flan_t5.py => models_test.py} (70%) rename tests/{test_package.py => package_test.py} (86%) rename tests/{test_strategies.py => strategies_test.py} (98%) create mode 100644 typings/IPython/__init__.pyi create mode 100644 typings/IPython/core/__init__.pyi create mode 100644 typings/IPython/core/getipython.pyi create mode 100644 typings/IPython/terminal/__init__.pyi create mode 100644 typings/IPython/terminal/debugger.pyi create mode 100644 typings/IPython/terminal/embed.pyi create mode 100644 typings/IPython/terminal/interactiveshell.pyi create mode 100644 typings/IPython/terminal/ipapp.pyi create mode 100644 typings/IPython/terminal/magics.pyi create mode 100644 typings/IPython/terminal/prompts.pyi create mode 100644 typings/IPython/terminal/pt_inputhooks/__init__.pyi create mode 100644 typings/IPython/terminal/ptutils.pyi create mode 100644 typings/IPython/terminal/shortcuts/__init__.pyi create mode 100644 typings/IPython/terminal/shortcuts/auto_match.pyi create mode 100644 typings/IPython/terminal/shortcuts/auto_suggest.pyi create mode 100644 typings/IPython/terminal/shortcuts/filters.pyi delete mode 100644 typings/click_option_group/_helpers.pyi create mode 100644 typings/docker/__init__.pyi create mode 100644 typings/docker/api/__init__.pyi create mode 100644 typings/docker/api/build.pyi create mode 100644 typings/docker/api/client.pyi create mode 100644 typings/docker/api/config.pyi create mode 100644 typings/docker/api/container.pyi create mode 100644 typings/docker/api/daemon.pyi create mode 100644 typings/docker/api/exec_api.pyi create mode 100644 typings/docker/api/image.pyi create mode 100644 typings/docker/api/network.pyi create mode 100644 typings/docker/api/plugin.pyi create mode 100644 typings/docker/api/secret.pyi create mode 100644 typings/docker/api/service.pyi create mode 100644 typings/docker/api/swarm.pyi create mode 100644 typings/docker/api/volume.pyi create mode 100644 typings/docker/client.pyi create mode 100644 typings/docker/constants.pyi create mode 100644 typings/docker/context/__init__.pyi create mode 100644 typings/docker/context/api.pyi create mode 100644 typings/docker/context/config.pyi create mode 100644 typings/docker/context/context.pyi create mode 100644 typings/docker/credentials/__init__.pyi create mode 100644 typings/docker/credentials/constants.pyi create mode 100644 typings/docker/credentials/errors.pyi create mode 100644 typings/docker/credentials/store.pyi create mode 100644 typings/docker/credentials/utils.pyi create mode 100644 typings/docker/errors.pyi create mode 100644 typings/docker/models/__init__.pyi create mode 100644 typings/docker/models/configs.pyi create mode 100644 typings/docker/models/containers.pyi create mode 100644 typings/docker/models/images.pyi create mode 100644 typings/docker/models/networks.pyi create mode 100644 typings/docker/models/nodes.pyi create mode 100644 typings/docker/models/plugins.pyi create mode 100644 typings/docker/models/resource.pyi create mode 100644 typings/docker/models/secrets.pyi create mode 100644 typings/docker/models/services.pyi create mode 100644 typings/docker/models/swarm.pyi create mode 100644 typings/docker/models/volumes.pyi create mode 100644 typings/docker/transport/__init__.pyi create mode 100644 typings/docker/transport/basehttpadapter.pyi create mode 100644 typings/docker/transport/npipeconn.pyi create mode 100644 typings/docker/transport/npipesocket.pyi create mode 100644 typings/docker/transport/sshconn.pyi create mode 100644 typings/docker/transport/ssladapter.pyi create mode 100644 typings/docker/transport/unixconn.pyi create mode 100644 typings/docker/types/__init__.pyi create mode 100644 typings/docker/types/base.pyi create mode 100644 typings/docker/types/containers.pyi create mode 100644 typings/docker/types/daemon.pyi create mode 100644 typings/docker/types/healthcheck.pyi create mode 100644 typings/docker/types/networks.pyi create mode 100644 typings/docker/types/services.pyi create mode 100644 typings/docker/types/swarm.pyi create mode 100644 typings/docker/utils/__init__.pyi create mode 100644 typings/docker/utils/build.pyi create mode 100644 typings/docker/utils/config.pyi create mode 100644 typings/docker/utils/decorators.pyi create mode 100644 typings/docker/utils/fnmatch.pyi create mode 100644 typings/docker/utils/json_stream.pyi create mode 100644 typings/docker/utils/proxy.pyi create mode 100644 typings/docker/utils/socket.pyi create mode 100644 typings/docker/utils/utils.pyi create mode 100644 typings/nbformat/__init__.pyi create mode 100644 typings/nbformat/v4/__init__.pyi create mode 100644 typings/nbformat/v4/convert.pyi create mode 100644 typings/nbformat/v4/nbbase.pyi create mode 100644 typings/nbformat/v4/nbjson.pyi create mode 100644 typings/nbformat/v4/rwbase.pyi diff --git a/.gitattributes b/.gitattributes index 49058e6e..535c916b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,5 @@ nightly-requirements.txt linguist-generated=true nightly-requirements-gpu.txt linguist-generated=true +tests/models/__snapshots__/* linguist-generated=true +typings/**/*.pyi linguist-generated=true * text=auto eol=lf diff --git a/.github/actions/setup-repo/action.yml b/.github/actions/setup-repo/action.yml index 1afa5cc9..a4ec9e5d 100644 --- a/.github/actions/setup-repo/action.yml +++ b/.github/actions/setup-repo/action.yml @@ -31,10 +31,6 @@ runs: with: python-version: ${{ inputs.python-version }} architecture: ${{ inputs.architecture }} - - name: Setup node - uses: actions/setup-node@v3 - with: - node-version: '17' - name: Get cache key prefix id: get-cache-key-prefix shell: bash @@ -54,10 +50,3 @@ runs: - name: Install dependencies shell: bash run: pip install hatch towncrier - - name: Install pyright - shell: bash - run: npm install -g npm@^7 pyright - - name: Setup bufbuild/buf - uses: bufbuild/buf-setup-action@v1.20.0 - with: - github_token: ${{ github.token }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ed754605..3017dfc1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,6 +29,20 @@ defaults: run: shell: bash --noprofile --norc -exo pipefail {0} jobs: + quality: + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + name: quality-check + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Setup CI + uses: ./.github/actions/setup-repo + with: + python-version: ${{ env.STABLE_PYTHON_VERSION }} + - name: Run type check + run: hatch run typing tests: runs-on: ubuntu-latest if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' }} @@ -47,7 +61,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Run tests - run: hatch run full + run: hatch run tests:python - name: Disambiguate coverage filename run: mv .coverage ".coverage.${{ matrix.os }}.${{ matrix.python-version }}" - name: Upload coverage data @@ -99,6 +113,7 @@ jobs: needs: - coverage - tests + - quality runs-on: ubuntu-latest steps: - name: Decide whether the needed jobs succeeded or failed diff --git a/.github/workflows/cleanup-cache.yml b/.github/workflows/cleanup-cache.yml deleted file mode 100644 index 2e6a39ca..00000000 --- a/.github/workflows/cleanup-cache.yml +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: cache-cleanup -on: - pull_request: - types: - - closed -jobs: - cleanup: - runs-on: ubuntu-latest - if: github.repository_owner == 'bentoml' - steps: - - name: Check out code - uses: actions/checkout@v3 - - name: Cleanup - run: | - gh extension install actions/gh-actions-cache - - REPO=${{ github.repository }} - BRANCH="refs/pull/${{ github.event.pull_request.number }}/merge" - - echo "Fetching list of cache key" - cacheKeysForPR=$(gh actions-cache list -R $REPO -B $BRANCH | cut -f 1 ) - - ## Setting this to not fail the workflow while deleting cache keys. - set +e - echo "Deleting caches..." - for cacheKey in $cacheKeysForPR - do - gh actions-cache delete $cacheKey -R $REPO -B $BRANCH --confirm - done - echo "Done" - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e19fe8b..a23648f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,16 +14,16 @@ ci: autoupdate_schedule: weekly - skip: [check-models-table-update, check-models-table-update, changelog-dry-run] - autofix_commit_msg: "ci: auto fixes from pre-commit.ci\nFor more information, see https://pre-commit.ci" + skip: [check-models-table-update, check-models-table-update, changelog-dry-run, typecheck] + autofix_commit_msg: "ci: auto fixes from pre-commit.ci\n\nFor more information, see https://pre-commit.ci" autoupdate_commit_msg: 'ci: pre-commit autoupdate [pre-commit.ci]' exclude: '.*\.(css|js|svg)$' repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 'v0.0.275' + rev: 'v0.0.277' hooks: - id: ruff - args: [--fix, --exit-non-zero-on-fix, --show-fixes] + args: [--exit-non-zero-on-fix, --show-fixes] - repo: https://github.com/psf/black rev: 23.3.0 hooks: @@ -37,6 +37,13 @@ repos: args: [--config=pyproject.toml] - repo: local hooks: + - id: typecheck + name: type-check + entry: pyright src/openllm --level error + types: [python] + language: node + pass_filenames: false + additional_dependencies: ['pyright@1.1.316'] - id: check-license-header name: check for license headers entry: ./tools/assert-license-headers @@ -69,3 +76,10 @@ repos: hooks: - id: trailing-whitespace - id: end-of-file-fixer + exclude: | + (?x)^( + tests/models/.* + )$ + - id: check-yaml + args: ['--unsafe'] + - id: check-toml diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index bba9a855..ae5d9c44 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -100,28 +100,25 @@ After setting up your environment, here's how you can start contributing: 3. Run all formatter and linter with `hatch`: ```bash - hatch run fmt + hatch run quality ``` 4. Write tests that verify your feature or fix (see [Writing Tests](#writing-tests) below). 5. Run all tests to ensure your changes haven't broken anything: ```bash - hatch run full + hatch run tests:python ``` - 6. Commit your changes: ```bash git commit -m "Add my feature" ``` - 7. Push your changes to your fork: ```bash git push origin feature/my-feature ``` - 8. Submit a Pull Request on GitHub. ## Using a custom fork @@ -141,7 +138,13 @@ directory and their filenames start with `test_`. Run all tests with: ```bash -hatch run full +hatch run tests:python +``` + +Run snapshot testing for model outputs: + +```bash +hatch run tests:models ``` ## Releasing a New Version diff --git a/README.md b/README.md index 991ea2c4..01fa3e0c 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ Discord
+ + PDM +

An open platform for operating large language models (LLMs) in production.
Fine-tune, serve, deploy, and monitor any LLMs with ease.

diff --git a/changelog.d/107.fix.md b/changelog.d/107.fix.md new file mode 100644 index 00000000..a70a8c52 --- /dev/null +++ b/changelog.d/107.fix.md @@ -0,0 +1,26 @@ +Fixes relative model_id handling for running LLM within the container. + +Added support for building container directly with `openllm build`. Users now +can do `openllm build --format=container`: + +```bash +openllm build flan-t5 --format=container +``` + +This is equivalent to: + +```bash +openllm build flan-t5 && bentoml containerize google-flan-t5-large-service +``` + +Added Snapshot testing and more robust edge cases for model testing + +General improvement in `openllm.LLM.import_model` where it will parse santised +parameters automatically. + +Fixes `openllm start ` to use correct `model_id`, ignoring `--model-id` +(The correct behaviour) + +Fixes `--workers-per-resource conserved` to respect `--device` + +Added initial interface for `LLM.embeddings` diff --git a/examples/langchain-chains-demo/service.py b/examples/langchain-chains-demo/service.py index e9b604a1..9eadcc41 100644 --- a/examples/langchain-chains-demo/service.py +++ b/examples/langchain-chains-demo/service.py @@ -13,9 +13,6 @@ # limitations under the License. from __future__ import annotations - -import subprocess -import sys import typing as t from langchain.chains import LLMChain @@ -36,11 +33,9 @@ class Query(BaseModel): def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM: - args = [sys.executable, "-m", "openllm", "download", model_name] - if model_id: - args += ["--model-id", model_id] - subprocess.check_output(args) - return OpenLLM(model_name=model_name, model_id=model_id, embedded=False) + lc_llm = OpenLLM(model_name=model_name, model_id=model_id, embedded=False) + lc_llm.runner.download_model() + return lc_llm llm = gen_llm("dolly-v2", model_id="databricks/dolly-v2-7b") diff --git a/hatch.toml b/hatch.toml index 742705e0..7e2ed149 100644 --- a/hatch.toml +++ b/hatch.toml @@ -8,8 +8,6 @@ dependencies = [ "tomlkit", # NOTE: Using under ./tools/update-readme.py "markdown-it-py", - # NOTE: pyright for type - "pyright", # NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy "coverage[toml]>=6.5", "filelock>=3.7.1", @@ -26,25 +24,28 @@ dependencies = [ ] features = ['flan-t5'] [envs.default.scripts] -_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml" changelog = "towncrier build --version main --draft" -fmt = ["tools", "pre-commit run --all-files"] -full = "_run_script --reruns 5 --reruns-delay 3 -r aR {args:tests}" -setup = "pre-commit install" -tools = [ +quality = [ "./tools/update-readme.py", "./tools/update-optional-dependencies.py", "./tools/update-config-stubs.py", "./tools/update-models-import.py", "- ./tools/add-license-headers .", + "pre-commit run --all-files", ] -typing = "pyright {args:src/openllm tests}" -[envs.test.overrides] +setup = "pre-commit install" +typing = "pre-commit run typecheck --all-files" +[envs.tests] +extra-dependencies = [ + # NOTE: interact with docker for container tests. + "docker", +] +[envs.tests.scripts] +_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml" +models = "_run_script -r aR {args:tests/models}" +python = "_run_script --reruns 5 --reruns-delay 3 --ignore tests/models -n 3 -r aR {args:tests}" +[envs.tests.overrides] env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT=" -env.HERMETIC_TESTS.type = [{ value = "container", if = ["true"] }, "virtual"] -[envs.test.scripts] -[[envs.test.matrix]] -python = ["3.8", "3.9", "3.10", "3.11"] [envs.coverage] dependencies = ["coverage[toml]>=6.5", "lxml", "orjson"] detached = true diff --git a/pyproject.toml b/pyproject.toml index d5a65f6e..e4ff436b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,12 +72,12 @@ all = [ "openllm[falcon]", "openllm[mpt]", "openllm[starcoder]", - "openllm[ggml]", - "openllm[playground]", "openllm[fine-tune]", - "openllm[agents]", - "openllm[flan-t5]", "openllm[openai]", + "openllm[playground]", + "openllm[flan-t5]", + "openllm[agents]", + "openllm[ggml]", ] chatglm = ["cpm-kernels", "sentencepiece"] falcon = ["einops", "xformers", "safetensors"] @@ -155,7 +155,7 @@ verbose = 2 whitelist-regex = ["test_.*"] [tool.pytest.ini_options] -addopts = ["-rfEX", "-pno:warnings"] +addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"] python_files = ["test_*.py", "*_test.py"] testpaths = ["tests"] @@ -183,70 +183,125 @@ line-length = 119 target-version = ["py311"] [tool.ruff] -exclude = ["tools"] +exclude = ["tools", "src/openllm/playground"] +extend-select = [ + "B", # flake8-bugbear + "I", # isort + "G", # flake8-logging-format + "D", # pydocstyle + "W", # pycodestyle + "Q", # flake8-quotes + "FA", # flake8-future-annotations + "S", # flake8-bandit + "TCH", # flake8-type-checking + "PLW", # pylint-warning + "PLR", # pylint-refactor + "PT", # flake8-pytest-style + "PYI", # flake8-pyi + "PERF", # perflint + "FLY", # flynt + "RUF", # Ruff-specific rules + "YTT", # flake8-2020 +] +fix = true ignore = [ - # Allow non-abstract empty methods in abstract base classes - "B027", - # Allow boolean positional values in function calls, like `dict.get(... True)` - "FBT003", - # Ignore checks for possible passwords - "S105", + "B027", # Allow non-abstract empty methods in abstract base classes + "FBT003", # Allow boolean positional values in function calls, like `dict.get(... True)` + "S105", # Ignore checks for possible passwords "S106", "S107", - # Ignore complexity - "C901", + "S603", # ignore subprocess.call "PLR0911", "PLR0912", "PLR0913", "PLR0915", - "E501", - "E741", + "PLR2004", # magic value to use constant + "E501", # ignore line length violation + "PYI021", # ignore docstring in stubs, as pyright will include docstring in stubs. + "D103", # Just missing docstring for magic methods. + "D102", + "D101", + "D100", + "TCH004", # don't move runtime import out, just warn about it + "RUF012", # mutable attributes to be used with ClassVar + "B905", # zip warning about strict, only applicable for 3.10+ ] line-length = 119 -target-version = "py311" +target-version = "py312" unfixable = [ - "F401", # Don't touch unused imports, just warn about it. + "F401", # Don't touch unused imports, just warn about it. + "TCH004", # Don't touch import outside of TYPE_CHECKING block ] +[tool.ruff.flake8-type-checking] +exempt-modules = ["typing", "typing_extensions", "."] +runtime-evaluated-base-classes = [ + "pydantic.BaseModel", + "openllm._configuration.LLMConfig", + "openllm._configuration.GenerationConfig", + "openllm._configuration.ModelSettings", +] +runtime-evaluated-decorators = ["attrs.define", "attrs.frozen"] + [tool.ruff.pydocstyle] convention = "google" +[tool.ruff.pycodestyle] +ignore-overlong-task-comments = true + [tool.ruff.isort] force-single-line = true known-first-party = ["openllm", "bentoml", 'transformers'] lines-after-imports = 2 +no-lines-before = ["future", "standard-library"] +relative-imports-order = "closest-to-furthest" -[tool.ruff.per-file-ignores] +[tool.ruff.flake8-quotes] +avoid-escape = false + +[tool.ruff.extend-per-file-ignores] # Tests can use magic values, assertions, and relative imports "__init__.py" = ["E402", "F401", "F403", "F811"] +"examples/**/*" = ["D"] +"src/openllm/_llm.py" = ["B010", "B009"] +"src/openllm/_strategies.py" = ["B904"] "src/openllm/_types.py" = ["E402"] -"src/openllm/playground/**/*" = ["E402", "F401"] -"tests/**/*" = ["PLR2004", "S101", "TID252"] +"src/openllm/cli.py" = ["D301", "S101"] +"src/openllm/models/**/*" = ["D106", "S101", "D104"] +"src/openllm/playground/**/*" = ["E402", "F401", "PLR", "D"] +"src/openllm/utils/dummy_*" = ["D107"] +"src/openllm/utils/import_utils.py" = [ + "PLW0603", # OK to ignore global access here + "D105", # magic docstring +] +"src/openllm_client/runtimes/*" = ["D107"] +"tests/**/*" = [ + "S101", + "TID252", + "D", # No docstring in tests + "PT011", # ignore too broad raises, as it can be use pytest.raises().match() + "S307", # Ignore eval(compile) as it is a known script execution +] +"typings/**/*" = ["D", "F", "E", "PYI002", "I001"] [tool.pyright] analysis.useLibraryCodeForTypes = true -enableTypeIgnoreComments = true -include = ["src/", "tests/", "tools/", "examples/"] +exclude = ["src/openllm/playground", "src/openllm/models/"] +include = ["src/openllm", "src/openllm_client", "tests/", "tools/", "examples/"] pythonVersion = "3.12" -reportMissingImports = "none" -reportMissingModuleSource = "warning" -reportMissingTypeStubs = "warning" +reportMissingTypeStubs = false +reportPrivateUsage = "warning" +reportUnknownArgumentType = "warning" reportUnknownMemberType = "warning" reportUnknownVariableType = "warning" -strictDictionaryInference = true -strictListInference = true -strictParameterNoneValue = true -strictSetInference = true typeCheckingMode = "strict" - [tool.coverage.run] branch = true omit = [ "src/openllm/playground/", "src/openllm/__about__.py", "src/openllm/__main__.py", - "src/openllm/tests.py", "src/openllm/utils/dummy_*.py", ] source_pkgs = ["openllm"] @@ -255,4 +310,14 @@ source_pkgs = ["openllm"] openllm = ["src/openllm", "*/openllm/src/openllm"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if t.TYPE_CHECKING:", "@overload", "# pragma: no cover"] +exclude_lines = [ + "no cov", + "pragma: no cover", + "if __name__ == .__main__.:", + "if t.TYPE_CHECKING:", + 'if TYPE_CHECKING:', + 'if typing.TYPE_CHECKING:', + '@overload', + '@typing.overload', + 'raise NotImplementedError', +] diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py index 7f396c0a..88c1af41 100644 --- a/src/openllm/__init__.py +++ b/src/openllm/__init__.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -OpenLLM -======= +"""OpenLLM. An open platform for operating large language models in production. Fine-tune, serve, deploy, and monitor any LLMs with ease. @@ -24,7 +22,6 @@ deploy, and monitor any LLMs with ease. * Native integration with BentoML and LangChain for custom LLM apps """ from __future__ import annotations - import logging import os import typing as t @@ -39,7 +36,6 @@ if utils.DEBUG: utils.set_debug_mode(True) utils.set_quiet_mode(False) - utils.configure_logging() logging.basicConfig(level=logging.NOTSET) else: # configuration for bitsandbytes before import @@ -64,16 +60,15 @@ else: _import_structure = { "_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"], "_configuration": ["LLMConfig"], - "_package": ["build"], "exceptions": [], "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput"], - "utils": [], + "utils": ["infer_auto_class"], "models": [], "client": [], "playground": [], - "tests": [], + "testing": [], "serialisation": ["ggml", "transformers"], - "cli": ["start", "start_grpc"], + "cli": ["start", "start_grpc", "build", "import_model", "list_models"], # NOTE: models "models.auto": [ "AutoConfig", @@ -182,23 +177,23 @@ if t.TYPE_CHECKING: from . import exceptions as exceptions from . import models as models from . import playground as playground - from . import tests as tests from . import serialisation as serialisation + from . import testing as testing # Specific types import from ._configuration import LLMConfig as LLMConfig from ._llm import LLM as LLM - from ._llm import LLMRunner as LLMRunner from ._llm import LLMRunnable as LLMRunnable + from ._llm import LLMRunner as LLMRunner from ._llm import Runner as Runner - from ._package import build as build from ._schema import GenerationInput as GenerationInput from ._schema import GenerationOutput as GenerationOutput from ._schema import MetadataOutput as MetadataOutput + from .cli import build as build + from .cli import import_model as import_model + from .cli import list_models as list_models from .cli import start as start from .cli import start_grpc as start_grpc - from .serialisation import ggml as ggml - from .serialisation import transformers as transformers from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES from .models.auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES @@ -213,6 +208,9 @@ if t.TYPE_CHECKING: from .models.opt import OPTConfig as OPTConfig from .models.stablelm import StableLMConfig as StableLMConfig from .models.starcoder import StarCoderConfig as StarCoderConfig + from .serialisation import ggml as ggml + from .serialisation import transformers as transformers + from .utils import infer_auto_class as infer_auto_class # NOTE: torch and cpm_kernels try: @@ -286,6 +284,7 @@ else: globals()["__file__"], _import_structure, module_spec=__spec__, + doc=__doc__, extra_objects={ "__version__": __version__, # The below is a special mapping that allows openllm to be used as a dictionary. diff --git a/src/openllm/__main__.py b/src/openllm/__main__.py index d42398a0..96c629ec 100644 --- a/src/openllm/__main__.py +++ b/src/openllm/__main__.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -CLI entrypoint for OpenLLM. +"""CLI entrypoint for OpenLLM. Usage: openllm --help diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index fc22c8a3..e23975e6 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``. +"""Configuration utilities for OpenLLM. All model configuration will inherit from ``openllm.LLMConfig``. Highlight feature: Each fields in ``openllm.LLMConfig`` will also automatically generate a environment variable based on its name field. @@ -47,7 +46,6 @@ dynamically during serve, ahead-of-serve or per requests. Refer to ``openllm.LLMConfig`` docstring for more information. """ from __future__ import annotations - import copy import enum import logging @@ -68,56 +66,48 @@ from deepmerge.merger import Merger import openllm from .exceptions import ForbiddenAttributeError +from .utils import ENV_VARS_TRUE_VALUES from .utils import LazyType -from .utils import ReprMixin, ENV_VARS_TRUE_VALUES +from .utils import ReprMixin from .utils import bentoml_cattr from .utils import codegen from .utils import dantic -from .utils import first_not_none, field_env_key +from .utils import field_env_key +from .utils import first_not_none from .utils import lenient_issubclass from .utils import non_intrusive_setattr from .utils import requires_dependencies -if hasattr(t, "Required"): - from typing import Required -else: - from typing_extensions import Required - -if hasattr(t, "NotRequired"): - from typing import NotRequired -else: - from typing_extensions import NotRequired - -if hasattr(t, "dataclass_transform"): - from typing import dataclass_transform -else: - from typing_extensions import dataclass_transform - -# NOTE: We need to do this so that overload can register +# NOTE: We need to do check overload import +# so that it can register # correct overloads to typing registry -if hasattr(t, "get_overloads"): +if sys.version_info[:2] >= (3, 11): + from typing import NotRequired + from typing import Required + from typing import dataclass_transform from typing import overload else: + from typing_extensions import NotRequired + from typing_extensions import Required + from typing_extensions import dataclass_transform from typing_extensions import overload _T = t.TypeVar("_T") if t.TYPE_CHECKING: + import click import peft - from attr import _CountingAttr # type: ignore - from attr import _make_init # type: ignore - from attr import _transform_attrs # type: ignore + from attr import _CountingAttr + from attr import _make_init + from attr import _transform_attrs from attr._compat import set_closure_cell import transformers from transformers.generation.beam_constraints import Constraint - from ._types import ClickFunctionWrapper - from ._types import F - from ._types import O_co - from ._types import P + from ._types import AnyCallable DictStrAny = dict[str, t.Any] ListStr = list[str] @@ -154,10 +144,10 @@ config_merger = Merger( # case insensitive, but rename to conform with type class _PeftEnumMeta(enum.EnumMeta): - def __getitem__(self, __key: str | t.Any) -> PeftType: + def __getitem__(self, __key: str | t.Any) -> enum.Enum: if isinstance(__key, str): __key = inflection.underscore(__key).upper() - return super().__getitem__(__key) + return self._member_map_[__key] # vendorred from peft.utils.config.PeftType @@ -171,11 +161,11 @@ class PeftType(enum.Enum, metaclass=_PeftEnumMeta): ADAPTION_PROMPT = "ADAPTION_PROMPT" @classmethod - def _missing_(cls, value: object) -> PeftType | None: + def _missing_(cls, value: object) -> enum.Enum | None: if isinstance(value, str): normalized = inflection.underscore(value).upper() if normalized in cls._member_map_: - return cls[normalized] + return cls._member_map_[normalized] @classmethod def supported(cls) -> set[str]: @@ -184,6 +174,11 @@ class PeftType(enum.Enum, metaclass=_PeftEnumMeta): def to_str(self) -> str: return self.value + @staticmethod + def get(__key: str | t.Any) -> PeftType: + """type-safe getitem.""" + return t.cast(PeftType, PeftType[__key]) + _PEFT_TASK_TYPE_TARGET_MAPPING = {"causal_lm": "CAUSAL_LM", "seq2seq_lm": "SEQ_2_SEQ_LM"} @@ -200,14 +195,33 @@ def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType: raise ValueError("'AdapterType' cannot be None.") if isinstance(value, PeftType): return value - if isinstance(value, str) and value not in PeftType.supported(): + if value not in PeftType.supported(): raise ValueError(f"Given '{value}' is not a supported adapter type.") - return PeftType[value] + return PeftType.get(value) @attr.define(slots=True) class FineTuneConfig: - """FineTuneConfig defines a default value for fine-tuning this any given LLM. For example: + """FineTuneConfig defines a default value for fine-tuning this any given LLM. + + For example: + + ```python + class FalconConfig(openllm.LLMConfig): + + __config__ = { + "fine_tune_strategies": ( + { + "adapter_type": "lora", + "r": 64, + "lora_alpha": 16, + "lora_dropout": 0.1, + "bias": "none", + "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], + }, + ), + } + ``` This is a lower level API that leverage `peft` as well as openllm.LLMConfig to create default and customization @@ -318,8 +332,7 @@ class FineTuneConfig: docs: str | None = None, **attrs: t.Any, ) -> type[FineTuneConfig]: - """A loose codegen to create default subclass for given adapter config type""" - + """A loose codegen to create default subclass for given adapter config type.""" _new_default = { "adapter_type": PeftType[adapter_type], "adapter_config": attrs, @@ -355,8 +368,7 @@ class FineTuneConfig: @attr.frozen(slots=True, repr=False) class GenerationConfig(ReprMixin): - """Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``, - with some additional validation and environment constructor. + """GenerationConfig is the attrs-compatible version of ``transformers.GenerationConfig``, with some additional validation and environment constructor. Note that we always set `do_sample=True`. This class is not designed to be used directly, rather to be used conjunction with LLMConfig. The instance of the generation config can then be accessed @@ -588,7 +600,7 @@ class GenerationConfig(ReprMixin): if t.TYPE_CHECKING: - def __attrs_init__(self, **_: t.Any): + def __attrs_init__(self, *args: t.Any, **attrs: t.Any): ... def __init__(self, *, _internal: bool = False, **attrs: t.Any): @@ -628,6 +640,7 @@ _object_getattribute = object.__getattribute__ class ModelSettings(t.TypedDict, total=False): """ModelSettings serve only for typing purposes as this is transcribed into LLMConfig.__config__. + Note that all fields from this dictionary will then be converted to __openllm_*__ fields in LLMConfig. If the field below changes, make sure to run ./tools/update-config-stubs.py to generate correct __getitem__ @@ -728,9 +741,9 @@ class _ModelSettingsAttr: service_name: str requirements: t.Optional[ListStr] bettertransformer: bool - model_type: t.Literal['causal_lm', 'seq2seq_lm'] - runtime: t.Literal['transformers', 'ggml'] - name_type: t.Optional[t.Literal['dasherize', 'lowercase']] + model_type: t.Literal["causal_lm", "seq2seq_lm"] + runtime: t.Literal["transformers", "ggml"] + name_type: t.Optional[t.Literal["dasherize", "lowercase"]] model_name: str start_name: str env: openllm.utils.EnvVarMixin @@ -743,7 +756,6 @@ class _ModelSettingsAttr: def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]): - assert cl_.__config__ is not None, f"'__config__' is required for {cls}." if "generation_class" in cl_.__config__: raise ValueError( "'generation_class' shouldn't be defined in '__config__', rather defining " @@ -813,8 +825,8 @@ bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings) def _setattr_class(attr_name: str, value_var: t.Any): - """ - Use the builtin setattr to set *attr_name* to *value_var*. + """Use the builtin setattr to set *attr_name* to *value_var*. + We can't use the cached object.__setattr__ since we are setting attributes to a class. @@ -828,7 +840,7 @@ def _setattr_class(attr_name: str, value_var: t.Any): def _make_assignment_script( cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: t.LiteralString = "openllm" ) -> t.Callable[..., None]: - """Generate the assignment script with prefix attributes __openllm___""" + """Generate the assignment script with prefix attributes __openllm___.""" args: ListStr = [] globs: DictStrAny = { "cls": cls, @@ -852,13 +864,14 @@ def _make_assignment_script( _reserved_namespace = {"__config__", "GenerationConfig"} -@dataclass_transform(order_default=True, field_specifiers=(attr.field, dantic.Field)) +@dataclass_transform(kw_only_default=True, order_default=True, field_specifiers=(attr.field, dantic.Field)) def llm_config_transform(cls: type[LLMConfig]) -> type[LLMConfig]: non_intrusive_setattr( cls, "__dataclass_transform__", { "order_default": True, + "kw_only_default": True, "field_specifiers": (attr.field, dantic.Field), }, ) @@ -880,7 +893,7 @@ class _ConfigAttr: # NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING if t.TYPE_CHECKING: # NOTE: public attributes to override - __config__: ModelSettings | None = Field(None) + __config__: ModelSettings = Field(None) """Internal configuration for this LLM model. Each of the field in here will be populated and prefixed with __openllm___""" GenerationConfig: type = Field(None) @@ -914,7 +927,7 @@ class _ConfigAttr: to create the generation_config argument that can be used throughout the lifecycle. This class will also be managed internally by OpenLLM.""" - def __attrs_init__(self, **attrs: t.Any): + def __attrs_init__(self, *args: t.Any, **attrs: t.Any): """Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract.""" # NOTE: The following will be populated from __config__ and also @@ -955,14 +968,14 @@ class _ConfigAttr: architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models. """ - __openllm_model_type__: t.Literal['causal_lm', 'seq2seq_lm'] = Field(None) + __openllm_model_type__: t.Literal["causal_lm", "seq2seq_lm"] = Field(None) """The model type for this given LLM. By default, it should be causal language modeling. Currently supported 'causal_lm' or 'seq2seq_lm' """ - __openllm_runtime__: t.Literal['transformers', 'ggml'] = Field(None) + __openllm_runtime__: t.Literal["transformers", "ggml"] = Field(None) """The runtime to use for this model. Possible values are `transformers` or `ggml`. See LlaMA for more information.""" - __openllm_name_type__: t.Optional[t.Literal['dasherize', 'lowercase']] = Field(None) + __openllm_name_type__: t.Optional[t.Literal["dasherize", "lowercase"]] = Field(None) """The default name typed for this model. "dasherize" will convert the name to lowercase and replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both `model_name` and `start_name` must be specified.""" @@ -991,11 +1004,187 @@ class _ConfigAttr: # fmt: on -@attr.define(slots=True) -class LLMConfig(_ConfigAttr): +class _ConfigBuilder: + """A modified version of attrs internal _ClassBuilder, and should only be called within __init_subclass__ of LLMConfig. + + Where: + - has_custom_setattr=True + - getstate_setstate=None (config class will always be a slotted class.) + - slots=True + - auto_attribs=False (We should handle it before _ConfigBuilder is invoked) + - cache_hash=False (We don't need to cache the hash code of this object for now.) + - collect_by_mro=True (The correct behaviour to resolve inheritance) + - field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable) + + It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__ """ - ``openllm.LLMConfig`` is somewhat a hybrid combination between the performance of `attrs` with the - easy-to-use interface that pydantic offer. It lives in between where it allows users to quickly formulate + + __slots__ = ( + "_cls", + "_cls_dict", + "_attr_names", + "_attrs", + "_model_name", + "_base_attr_map", + "_base_names", + "_has_pre_init", + "_has_post_init", + ) + + def __init__( + self, + cls: type[LLMConfig], + these: dict[str, _CountingAttr[t.Any]], + auto_attribs: bool = False, + kw_only: bool = False, + collect_by_mro: bool = True, + ): + attrs, base_attrs, base_attr_map = _transform_attrs( + cls, + these, + auto_attribs, + kw_only, + collect_by_mro, + field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__), + ) + + self._cls = cls + self._model_name = cls.__openllm_model_name__ + self._cls_dict = dict(cls.__dict__) + self._attrs = attrs + self._base_names = {a.name for a in base_attrs} + self._base_attr_map = base_attr_map + self._attr_names = tuple(a.name for a in attrs) + self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) + self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) + + self._cls_dict["__attrs_attrs__"] = self._attrs + + def build_class(self) -> type[LLMConfig]: + """Finalize class based on the accumulated configuration. + + Builder cannot be used after calling this method. + + > A difference between this and attrs._ClassBuilder is that we don't + > create a new class after constructing all __dict__. This has to do + > with recursive called within __init_subclass__ + """ + cd = { + k: v for k, v in self._cls_dict.items() if k not in (*tuple(self._attr_names), "__dict__", "__weakref__") + } + # Traverse the MRO to collect existing slots + # and check for an existing __weakref__. + existing_slots: DictStrAny = {} + weakref_inherited = False + for base_cls in self._cls.__mro__[1:-1]: + if base_cls.__dict__.get("__weakref__", None) is not None: + weakref_inherited = True + existing_slots.update( + {name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])} + ) + + base_names = set(self._base_names) + names = self._attr_names + if ( + "__weakref__" not in getattr(self._cls, "__slots__", ()) + and "__weakref__" not in names + and not weakref_inherited + ): + names += ("__weakref__",) + + # We only add the names of attributes that aren't inherited. + # Setting __slots__ to inherited attributes wastes memory. + slot_names = [name for name in names if name not in base_names] + # There are slots for attributes from current class + # that are defined in parent classes. + # As their descriptors may be overridden by a child class, + # we collect them here and update the class dict + reused_slots = { + slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names + } + # We only add the names of attributes that aren't inherited. + # Setting __slots__ to inherited attributes wastes memory. + # __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots + slot_names = [name for name in slot_names if name not in reused_slots] + cd.update(reused_slots) + cd["__slots__"] = tuple(slot_names) + + cd["__qualname__"] = self._cls.__qualname__ + + # We can only patch the class here, rather than instantiate + # a new one, since type.__new__ actually will invoke __init_subclass__ + # and since we use the _ConfigBuilder in __init_subclass__, it will + # raise recusion error. See https://peps.python.org/pep-0487/ for more + # information on how __init_subclass__ works. + for k, value in cd.items(): + setattr(self._cls, k, value) + + return self.make_closure(self._cls) + + def make_closure(self, cls: type): + # The following is a fix for + # . + # If a method mentions `__class__` or uses the no-arg super(), the + # compiler will bake a reference to the class in the method itself + # as `method.__closure__`. Since we replace the class with a + # clone, we rewrite these references so it keeps working. + for item in cls.__dict__.values(): + if isinstance(item, (classmethod, staticmethod)): + # Class- and staticmethods hide their functions inside. + # These might need to be rewritten as well. + closure_cells = getattr(item.__func__, "__closure__", None) + elif isinstance(item, property): + # Workaround for property `super()` shortcut (PY3-only). + # There is no universal way for other descriptors. + closure_cells = getattr(item.fget, "__closure__", None) + else: + closure_cells = getattr(item, "__closure__", None) + + if not closure_cells: # Catch None or the empty list. + continue + for cell in closure_cells: + try: + match = cell.cell_contents is self._cls + except ValueError: # ValueError: Cell is empty + pass + else: + if match: + set_closure_cell(cell, cls) + + return llm_config_transform(cls) + + def add_attrs_init(self) -> t.Self: + self._cls_dict["__attrs_init__"] = codegen.add_method_dunders( + self._cls, + _make_init( + self._cls, + self._attrs, + self._has_pre_init, + self._has_post_init, + False, # frozen + True, # slots + False, # cache_hash + self._base_attr_map, + False, # This is not an exception + None, # no on_setattr + True, + ), + ) + return self + + def add_repr(self): + for key, fn in ReprMixin.__dict__.items(): + if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): + self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn) + self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"}) + return self + + +@attr.define(slots=True, init=False) +class LLMConfig(_ConfigAttr): + """``openllm.LLMConfig`` is a pydantic-like ``attrs`` interface that offers fast and easy-to-use APIs. + + It lives in between the nice UX of `pydantic` and fast performance of `attrs` where it allows users to quickly formulate a LLMConfig for any LLM without worrying too much about performance. It does a few things: - Automatic environment conversion: Each fields will automatically be provisioned with an environment @@ -1077,182 +1266,16 @@ class LLMConfig(_ConfigAttr): ), } ``` + + Future work: + - Support pydantic-core as validation backend. """ - class _ConfigBuilder: - """A modified version of attrs internal _ClassBuilder, should only be called - within __init_subclass__ of LLMConfig. - - Where: - - has_custom_setattr=True - - getstate_setstate=None (config class will always be a slotted class.) - - slots=True - - auto_attribs=False (We should handle it before _ConfigBuilder is invoked) - - cache_hash=False (We don't need to cache the hash code of this object for now.) - - collect_by_mro=True (The correct behaviour to resolve inheritance) - - field_transformer=codegen.make_env_transformer (We need to transform the field to have env variable) - - It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__ - """ - - __slots__ = ( - "_cls", - "_cls_dict", - "_attr_names", - "_attrs", - "_model_name", - "_base_attr_map", - "_base_names", - "_has_pre_init", - "_has_post_init", - ) - - def __init__( - self, - cls: type[LLMConfig], - these: dict[str, _CountingAttr[t.Any]], - auto_attribs: bool = False, - kw_only: bool = False, - collect_by_mro: bool = True, - ): - attrs, base_attrs, base_attr_map = _transform_attrs( - cls, - these, - auto_attribs, - kw_only, - collect_by_mro, - field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__), - ) - - self._cls = cls - self._model_name = cls.__openllm_model_name__ - self._cls_dict = dict(cls.__dict__) - self._attrs = attrs - self._base_names = {a.name for a in base_attrs} - self._base_attr_map = base_attr_map - self._attr_names = tuple(a.name for a in attrs) - self._has_pre_init = bool(getattr(cls, "__attrs_pre_init__", False)) - self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) - - self._cls_dict["__attrs_attrs__"] = self._attrs - - def build_class(self) -> type[LLMConfig]: - """ - Finalize class based on the accumulated configuration. - - Builder cannot be used after calling this method. - - > A difference between this and attrs._ClassBuilder is that we don't - > create a new class after constructing all __dict__. This has to do - > with recursive called within __init_subclass__ - """ - cd = { - k: v - for k, v in self._cls_dict.items() - if k not in tuple(self._attr_names) + ("__dict__", "__weakref__") - } - # Traverse the MRO to collect existing slots - # and check for an existing __weakref__. - existing_slots: DictStrAny = {} - weakref_inherited = False - for base_cls in self._cls.__mro__[1:-1]: - if base_cls.__dict__.get("__weakref__", None) is not None: - weakref_inherited = True - existing_slots.update( - {name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, "__slots__", [])} - ) - - base_names = set(self._base_names) - names = self._attr_names - if ( - "__weakref__" not in getattr(self._cls, "__slots__", ()) - and "__weakref__" not in names - and not weakref_inherited - ): - names += ("__weakref__",) - - # We only add the names of attributes that aren't inherited. - # Setting __slots__ to inherited attributes wastes memory. - slot_names = [name for name in names if name not in base_names] - # There are slots for attributes from current class - # that are defined in parent classes. - # As their descriptors may be overridden by a child class, - # we collect them here and update the class dict - reused_slots = { - slot: slot_descriptor for slot, slot_descriptor in existing_slots.items() if slot in slot_names - } - # We only add the names of attributes that aren't inherited. - # Setting __slots__ to inherited attributes wastes memory. - # __openllm_extras__ holds additional metadata that might be usefule for users, hence we add it to slots - slot_names = [name for name in slot_names if name not in reused_slots] - cd.update(reused_slots) - cd["__slots__"] = tuple(slot_names) - - for k, value in cd.items(): - setattr(self._cls, k, value) - - # The following is a fix for - # . - # If a method mentions `__class__` or uses the no-arg super(), the - # compiler will bake a reference to the class in the method itself - # as `method.__closure__`. Since we replace the class with a - # clone, we rewrite these references so it keeps working. - for item in self._cls.__dict__.values(): - if isinstance(item, (classmethod, staticmethod)): - # Class- and staticmethods hide their functions inside. - # These might need to be rewritten as well. - closure_cells = getattr(item.__func__, "__closure__", None) - elif isinstance(item, property): - # Workaround for property `super()` shortcut (PY3-only). - # There is no universal way for other descriptors. - closure_cells = getattr(item.fget, "__closure__", None) - else: - closure_cells = getattr(item, "__closure__", None) - - if not closure_cells: # Catch None or the empty list. - continue - for cell in closure_cells: - try: - match = cell.cell_contents is self._cls - except ValueError: # ValueError: Cell is empty - pass - else: - if match: - set_closure_cell(cell, self._cls) - - return llm_config_transform(self._cls) - - def add_attrs_init(self) -> t.Self: - self._cls_dict["__attrs_init__"] = codegen.add_method_dunders( - self._cls, - _make_init( - self._cls, - self._attrs, - self._has_pre_init, - self._has_post_init, - False, # frozen - True, # slots - False, # cache_hash - self._base_attr_map, - False, # This is not an exception - None, # no on_setattr - attrs_init=True, - ), - ) - return self - - def add_repr(self): - for key, fn in ReprMixin.__dict__.items(): - if key in ("__repr__", "__str__", "__repr_name__", "__repr_str__", "__repr_args__"): - self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn) - self._cls_dict["__repr_keys__"] = property(lambda _: {i.name for i in self._attrs} | {"generation_config"}) - return self - def __init_subclass__(cls: type[LLMConfig]): - """The purpose of this __init_subclass__ is that we want all subclass of LLMConfig - to adhere to the attrs contract, and have pydantic-like interface. This means we will - construct all fields and metadata and hack into how attrs use some of the 'magic' construction - to generate the fields. + """The purpose of this ``__init_subclass__`` is to offer pydantic UX while adhering to attrs contract. + + This means we will construct all fields and metadata and hack into + how attrs use some of the 'magic' construction to generate the fields. It also does a few more extra features: It also generate all __openllm_*__ config from ModelSettings (derived from __config__) to the class. @@ -1261,7 +1284,7 @@ class LLMConfig(_ConfigAttr): logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__) cls.__name__ = f"{cls.__name__}Config" - if not hasattr(cls, "__config__") or cls.__config__ is None: + if not hasattr(cls, "__config__"): raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.") # auto assignment attributes generated from __config__ after create the new slot class. @@ -1320,7 +1343,7 @@ class LLMConfig(_ConfigAttr): a.name for a in attr.fields(cls.__openllm_generation_class__) } - cls = cls._ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class() + cls = _ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class() # Finally, resolve the types if getattr(cls, "__attrs_types_resolved__", None) != cls: @@ -1398,11 +1421,11 @@ class LLMConfig(_ConfigAttr): @overload def __getitem__(self, item: t.Literal["bettertransformer"] = ...) -> bool: ... @overload - def __getitem__(self, item: t.Literal["model_type"] = ...) -> t.Literal['causal_lm', 'seq2seq_lm']: ... + def __getitem__(self, item: t.Literal["model_type"] = ...) -> t.Literal["causal_lm", "seq2seq_lm"]: ... @overload - def __getitem__(self, item: t.Literal["runtime"] = ...) -> t.Literal['transformers', 'ggml']: ... + def __getitem__(self, item: t.Literal["runtime"] = ...) -> t.Literal["transformers", "ggml"]: ... @overload - def __getitem__(self, item: t.Literal["name_type"] = ...) -> t.Optional[t.Literal['dasherize', 'lowercase']]: ... + def __getitem__(self, item: t.Literal["name_type"] = ...) -> t.Optional[t.Literal["dasherize", "lowercase"]]: ... @overload def __getitem__(self, item: t.Literal["model_name"] = ...) -> str: ... @overload @@ -1516,7 +1539,7 @@ class LLMConfig(_ConfigAttr): # fmt: on def __getitem__(self, item: t.LiteralString | t.Any = None) -> t.Any: - """Allowing access LLMConfig as a dictionary. The order will always evaluate as + """Allowing access LLMConfig as a dictionary. The order will always evaluate as. __openllm_*__ > self.key > self.generation_config > self['fine_tune_strategies'] > __openllm_extras__ @@ -1599,7 +1622,6 @@ class LLMConfig(_ConfigAttr): **attrs: The attributes to be added to the new class. This will override any existing attributes with the same name. """ - assert cls.__config__ is not None, "Cannot derivate a LLMConfig without __config__" _new_cfg = {k: v for k, v in attrs.items() if k in attr.fields_dict(_ModelSettingsAttr)} attrs = {k: v for k, v in attrs.items() if k not in _new_cfg} new_cls = types.new_class( @@ -1642,14 +1664,12 @@ class LLMConfig(_ConfigAttr): try: attrs = orjson.loads(json_str) except orjson.JSONDecodeError as err: - raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") + raise openllm.exceptions.ValidationError(f"Failed to load JSON: {err}") from None return bentoml_cattr.structure(attrs, cls) @classmethod def model_construct_env(cls, **attrs: t.Any) -> t.Self: - """A helpers that respect configuration values that - sets from environment variables for any given configuration class. - """ + """A helpers that respect configuration values environment variables.""" attrs = {k: v for k, v in attrs.items() if v is not None} model_config = cls.__openllm_env__.config @@ -1696,7 +1716,7 @@ class LLMConfig(_ConfigAttr): return self.model_construct_env(**llm_config_attrs), {k: v for k, v in attrs.items() if k not in key_to_remove} @overload - def to_generation_config(self, return_as_dict: t.Literal[False] = ...) -> transformers.GenerationConfig: + def to_generation_config(self, return_as_dict: t.Literal[False] = False) -> transformers.GenerationConfig: ... @overload @@ -1708,22 +1728,12 @@ class LLMConfig(_ConfigAttr): return config.to_dict() if return_as_dict else config @classmethod - @overload - def to_click_options( - cls, f: t.Callable[..., openllm.LLMConfig] - ) -> F[P, ClickFunctionWrapper[..., openllm.LLMConfig]]: - ... + def to_click_options(cls, f: AnyCallable) -> click.Command: + """Convert current configuration to click options. - @classmethod - @overload - def to_click_options(cls, f: t.Callable[P, O_co]) -> F[P, ClickFunctionWrapper[P, O_co]]: - ... + This can be used as a decorator for click commands. - @classmethod - def to_click_options(cls, f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - """ - Convert current model to click options. This can be used as a decorator for click commands. - Note that the identifier for all LLMConfig will be prefixed with '_*', and the generation config + > **Note**: that the identifier for all LLMConfig will be prefixed with '_*', and the generation config will be prefixed with '_generation_*'. """ for name, field in attr.fields_dict(cls.__openllm_generation_class__).items(): @@ -1769,8 +1779,7 @@ bentoml_cattr.register_unstructure_hook_factory( def structure_llm_config(data: DictStrAny, cls: type[LLMConfig]) -> LLMConfig: - """ - Structure a dictionary to a LLMConfig object. + """Structure a dictionary to a LLMConfig object. Essentially, if the given dictionary contains a 'generation_config' key, then we will use it for LLMConfig.generation_config diff --git a/src/openllm/_generation.py b/src/openllm/_generation.py index 09875867..aaf5d4d1 100644 --- a/src/openllm/_generation.py +++ b/src/openllm/_generation.py @@ -14,14 +14,15 @@ """Generation utilities to be reused throughout.""" from __future__ import annotations - import typing as t -import torch - import transformers +if t.TYPE_CHECKING: + import torch + + class StopSequenceCriteria(transformers.StoppingCriteria): """This class used to stop generation when a seq of tokens are met. @@ -42,6 +43,6 @@ class StopSequenceCriteria(transformers.StoppingCriteria): class StopOnTokens(transformers.StoppingCriteria): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: t.Any) -> bool: stop_ids = {50278, 50279, 50277, 1, 0} - return input_ids[0][-1] in stop_ids + return t.cast(int, input_ids[0][-1]) in stop_ids diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 401d1d97..feca2f80 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import collections import functools import hashlib @@ -21,26 +20,25 @@ import inspect import logging import os import re -import subprocess import sys import types import typing as t from abc import ABC from abc import abstractmethod - from pathlib import Path + import attr -import inflection import orjson from huggingface_hub import hf_hub_download import bentoml import openllm +from bentoml._internal.models import ModelStore from bentoml._internal.models.model import ModelSignature -from bentoml._internal.types import ModelSignatureDict -from ._configuration import FineTuneConfig from ._configuration import AdapterType +from ._configuration import FineTuneConfig +from ._quantisation import infer_quantisation_config from .exceptions import ForbiddenAttributeError from .exceptions import GpuNotAvailableError from .utils import DEBUG @@ -49,23 +47,27 @@ from .utils import EnvVarMixin from .utils import LazyLoader from .utils import ReprMixin from .utils import bentoml_cattr +from .utils import codegen from .utils import first_not_none -from .utils import is_bitsandbytes_available +from .utils import get_debug_mode +from .utils import in_docker from .utils import is_peft_available from .utils import is_torch_available -from .utils import is_transformers_supports_kbit from .utils import non_intrusive_setattr -from .utils import pkg -from .utils import validate_is_path, resolve_filepath -from .utils import requires_dependencies from .utils import normalize_attrs_to_model_tokenizer_pair +from .utils import pkg +from .utils import requires_dependencies +from .utils import resolve_filepath +from .utils import validate_is_path # NOTE: We need to do this so that overload can register # correct overloads to typing registry -if hasattr(t, "get_overloads"): +if sys.version_info[:2] >= (3, 11): + from typing import NotRequired from typing import overload else: + from typing_extensions import NotRequired from typing_extensions import overload if t.TYPE_CHECKING: @@ -75,26 +77,22 @@ if t.TYPE_CHECKING: import transformers from bentoml._internal.runner.strategy import Strategy - from ._types import ( - ModelProtocol, - TokenizerProtocol, - LLMRunner, - PeftAdapterOutput, - LLMRunnable as LLMRunnable, - LLMInitAttrs, - AdaptersMapping, - ) - from .models.auto.factory import _BaseAutoLLMClass + from ._types import AdaptersMapping + from ._types import AdaptersTuple + from ._types import DictStrAny + from ._types import LiteralRuntime + from ._types import LLMInitAttrs + from ._types import LLMRunnable + from ._types import LLMRunner + from ._types import PeftAdapterOutput + from ._types import TupleAny from .utils.representation import ReprArgs - DictStrAny = dict[str, t.Any] - TupleAny = tuple[t.Any, ...] - ListAny = list[t.Any] UserDictAny = collections.UserDict[str, t.Any] + ResolvedAdaptersMapping = dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] else: DictStrAny = dict TupleAny = tuple - ListAny = list UserDictAny = collections.UserDict LLMRunnable = bentoml.Runnable LLMRunner = bentoml.Runner @@ -105,6 +103,13 @@ else: logger = logging.getLogger(__name__) +class ModelSignatureDict(t.TypedDict, total=False): + batchable: bool + batch_dim: t.Union[t.Tuple[int, int], int] + input_spec: NotRequired[t.Union[t.Any, t.Tuple[t.Any]]] + output_spec: NotRequired[t.Any] + + def normalise_model_name(name: str) -> str: return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else re.sub("[^a-zA-Z0-9]+", "-", name) @@ -113,7 +118,7 @@ def make_tag( model_id: str, model_version: str | None = None, trust_remote_code: bool = False, - implementation: t.Literal["pt", "flax", "tf"] = "pt", + implementation: LiteralRuntime = "pt", quiet: bool = False, ) -> bentoml.Tag: """Generate a ``bentoml.Tag`` from a given transformers model name. @@ -128,6 +133,7 @@ def make_tag( trust_remote_code: Whether to trust the remote code. Defaults to False. model_version: Optional model version to be saved with this tag. implementation: Given implementation for said LLM. One of t.Literal['pt', 'tf', 'flax'] + quiet: Whether to show warning logs. Default to 'False' Returns: A tuple of ``bentoml.Tag`` and a dict of unused kwargs. @@ -136,13 +142,21 @@ def make_tag( if validate_is_path(model_id): model_id = resolve_filepath(model_id) - if model_version is None: - if not quiet: - logger.warning( - "Given 'model_id=%s' is a path, and 'model_version' is not passed. OpenLLM will generate the version based on the last modified time of this given directory.", - model_id, - ) - model_version = generate_hash_from_file(model_id) + # special cases, if it is the model store, then we return the tags + # this will happens within the container, where we use the relative path + if in_docker() and os.getenv("BENTO_PATH") is not None: + _store = ModelStore(Path(model_id).parent.parent) + tag = _store.list()[0].tag + model_version = tag.version + model_name = tag.name + else: + if model_version is None: # noqa: PLR5501 + if not quiet: + logger.warning( + "Given 'model_id=%s' is a path, and 'model_version' is not passed. OpenLLM will generate the version based on the last modified time of this given directory.", + model_id, + ) + model_version = generate_hash_from_file(model_id) else: config = t.cast( "transformers.PretrainedConfig", @@ -159,16 +173,21 @@ def make_tag( f"Internal errors when parsing config for pretrained {model_id} ('commit_hash' not found)" ) - logger.debug( - "'model_id=%s' will use 'model_version=%s'. The full tag to be saved under model store: '%s-%s:%s'", - model_id, - model_version, - implementation, - model_name, - model_version, - ) + if in_docker() and os.getenv("BENTO_PATH") is not None: + logger.debug("The model will be loaded as relative path within BentoContainer.") + else: + logger.debug( + "'model_id=%s' will use 'model_version=%s'. The full tag to be saved under model store: '%s-%s:%s'", + model_id, + model_version, + implementation, + model_name, + model_version, + ) - return bentoml.Tag.from_taglike(f"{implementation}-{model_name}:{model_version}".strip()) + return bentoml.Tag.from_taglike( + f"{model_name if in_docker() and os.getenv('BENTO_PATH') is not None else implementation + '-' + model_name}:{model_version}".strip() + ) @functools.lru_cache(maxsize=128) @@ -193,7 +212,11 @@ PEFT_CONFIG_NAME = "adapter_config.json" def resolve_peft_config_type(adapter_map: dict[str, str | None] | None): """Resolve the type of the PeftConfig given the adapter_map. + This is similar to how PeftConfig resolve its config type. + + Args: + adapter_map: The given mapping from either SDK or CLI. See CLI docs for more information. """ if adapter_map is None: return @@ -203,30 +226,31 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None] | None): for path_or_adapter_id, name in adapter_map.items(): if _has_set_default: raise ValueError("Only one adapter can be set as default.") + resolve_name = name + if resolve_name is None: + resolve_name = "default" + _has_set_default = True if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)): config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME) else: try: config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME) - except Exception: - raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") + except Exception as err: + raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err with open(config_file, "r") as file: resolved_config = orjson.loads(file.read()) # all peft_type should be available in PEFT_CONFIG_NAME _peft_type: AdapterType = resolved_config["peft_type"].lower() if _peft_type not in resolved: resolved[_peft_type] = () - resolved[_peft_type] += ((path_or_adapter_id, name, resolved_config),) - if name == "default": - _has_set_default = True + resolved[_peft_type] += (_AdaptersTuple((path_or_adapter_id, resolve_name, resolved_config)),) return resolved _reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"} - -_M = t.TypeVar("_M", bound="ModelProtocol[t.Any]") -_T = t.TypeVar("_T", bound="TokenizerProtocol[t.Any]") +M = t.TypeVar("M", bound="transformers.PreTrainedModel") +T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer]") def _default_post_init(self: LLM[t.Any, t.Any]): @@ -236,21 +260,36 @@ def _default_post_init(self: LLM[t.Any, t.Any]): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -class LLMInterface(ABC, t.Generic[_M, _T]): +class LLMInterface(ABC, t.Generic[M, T]): """This defines the loose contract for all openllm.LLM implementations.""" @property - def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]] | None: + def import_kwargs(self) -> tuple[DictStrAny, DictStrAny] | None: """The default import kwargs to used when importing the model. + This will be passed into 'openllm.LLM.import_model'. It returns two dictionaries: one for model kwargs and one for tokenizer kwargs. + + Returns: + Optional tuple of model kwargs and tokenizer kwargs """ return + def embeddings(self, prompt: str) -> torch.Tensor: + """The implementation for generating text embeddings from given prompt. + + It takes the prompt and output the embeddings for this given LLM. + + Returns: + The embeddings for the given prompt. + """ + raise NotImplementedError + @abstractmethod def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any: - """The main function implementation for generating from given prompt. It takes the prompt - and the generation_kwargs from 'self.sanitize_parameters' and then + """The implementation for text generation from given prompt. + + It takes the prompt and 'generation_kwargs' from 'self.sanitize_parameters' and then pass it to 'self.model.generate'. """ raise NotImplementedError @@ -261,18 +300,20 @@ class LLMInterface(ABC, t.Generic[_M, _T]): stop: list[str], **preprocess_generate_kwds: t.Any, ) -> list[dict[t.Literal["generated_text"], str]]: - """The entrypoint for generating one prompt. This provides additional stop - tokens for generating per token level. This is useful when running with agents, or initial streaming support. + """The entrypoint for generating one prompt. + + This provides additional stop tokens for generating per token level. + This is useful when running with agents, or initial streaming support. """ raise NotImplementedError def generate_iterator(self, prompt: str, **attrs: t.Any) -> t.Iterator[t.Any]: - """An iterator version of generate function.""" + """T iterator version of `generate` function.""" raise NotImplementedError( "Currently generate_iterator requires SSE (Server-side events) support, which is not yet implemented." ) - def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: + def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStrAny, DictStrAny]: """This handler will sanitize all attrs and setup prompt text. It takes a prompt that is given by the user, attrs that can be parsed with the prompt. @@ -284,8 +325,7 @@ class LLMInterface(ABC, t.Generic[_M, _T]): return prompt, attrs, attrs def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any: - """This handler will postprocess generation results from LLM.generate and - then output nicely formatted results (if the LLM decide to do so.) + """This handler will postprocess generation results from LLM.generate and then output nicely formatted results (if the LLM decide to do so.). You can customize how the output of the LLM looks with this hook. By default, it is a simple echo. @@ -294,29 +334,34 @@ class LLMInterface(ABC, t.Generic[_M, _T]): return generation_result def llm_post_init(self): - """This function can be implemented if you need to initialized any additional variables that doesn't - concern OpenLLM internals. - """ + """This function can be implemented if you need to initialized any additional variables that doesn't concern OpenLLM internals.""" pass def import_model(self, *args: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: """This function can be implemented if default import_model doesn't satisfy your needs. - Note that tokenizer kwds can be accessed via ``llm.llm_parameters`` + + Note that tokenizer attrs can be accessed via ``llm.llm_parameters``. ```python - (model_decls, model_attrs), tokenizer_attrs = llm.llm_parameters + _, tokenizer_attrs = llm.llm_parameters ``` + + By default, `model_decls` and `model_attrs` is already sanitised and concatenated into `args` and `attrs` """ raise NotImplementedError - def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any: - """This function can be implemented to override the default load_model behaviour. See falcon for - example implementation.""" + def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> M: + """This function can be implemented to override the default load_model behaviour. + + See falcon for example implementation. + """ raise NotImplementedError - def load_tokenizer(self, tag: bentoml.Tag, **attrs: t.Any) -> t.Any: - """This function can be implemented to override how to load the tokenizer. See falcon for - example implementation.""" + def load_tokenizer(self, tag: bentoml.Tag, **attrs: t.Any) -> T: + """This function can be implemented to override how to load the tokenizer. + + See falcon for example implementation. + """ raise NotImplementedError # NOTE: All fields below are attributes that can be accessed by users. @@ -339,30 +384,34 @@ class LLMInterface(ABC, t.Generic[_M, _T]): # NOTE: The following will be populated by __init_subclass__, note that these should be immutable. __llm_trust_remote_code__: bool """This is used to determine during 'import_model' whether to trust remote code or not. + This works synonymous with `trust_remote_code` kwarg in transformers Auto classes. If not passed, then by default fallback to config_class['trust_remote_code'] """ - __llm_implementation__: t.Literal["pt", "tf", "flax"] - """This is used to determine which implementation that this LLM has. Usually, this will inferred from - class name, that follows the HuggingFace's naming convention: + __llm_implementation__: LiteralRuntime + """This is used to determine which implementation that this LLM has. + + Usually, this will inferred from class name, that follows the HuggingFace's naming convention: - `OPTForConditionalGeneration` -> `pt` - `TFOPTForConditionalGeneration` -> `tf` - `FlaxOPTForConditionalGeneration` -> `flax` """ - __llm_model__: _M | peft.PeftModel | torch.nn.Module | None + __llm_model__: M | peft.PeftModel | None """A reference to the actual model. Instead of access this directly, you should use `model` property instead.""" - __llm_tokenizer__: _T | None + __llm_tokenizer__: T | None """A reference to the actual tokenizer. Instead of access this directly, you should use `tokenizer` property instead.""" __llm_bentomodel__: bentoml.Model | None """A reference to the bentomodel used for this LLM. Instead of access this directly, you should use `_bentomodel` property instead.""" __llm_adapter_map__: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] | None """A reference to the the cached LoRA adapter mapping.""" - __llm_custom_load__: t.Callable[[t.Self, t.Any, t.Any], None] | None - """A callable that will be called after the model is loaded. This is set when 'load_model' is implemented""" - __llm_custom_tokenizer__: t.Callable[[t.Self, t.Any], None] | None - """A callable that will be called after the tokenizer is loaded. This is set when 'load_tokenizer' is implemented""" + __llm_custom_import__: bool + """Whether this LLM has a custom import_model""" + __llm_custom_load__: bool + """A boolean to determine whether a custom 'load_model' is implemented""" + __llm_custom_tokenizer__: bool + """A boolean to determine whether a custom 'load_tokenizer' is implemented""" __llm_init_kwargs__: property | None """A check if 'import_kwargs' is implemented in subclass.""" @@ -380,18 +429,24 @@ class LLMInterface(ABC, t.Generic[_M, _T]): tag: bentoml.Tag, adapters_mapping: AdaptersMapping, model_version: str | None, + quantize_method: t.Literal["int8", "int4", "gptq"] | None, /, **attrs: t.Unpack[LLMInitAttrs], ) -> None: - """Generated __attrs_init__ for openllm.LLM""" + """Generated __attrs_init__ for openllm.LLM.""" + + +_AdaptersTuple: type[AdaptersTuple] = codegen.make_attr_tuple_class("AdaptersTuple", ["adapter_id", "name", "config"]) @attr.define(slots=True, repr=False) -class LLM(LLMInterface[_M, _T], ReprMixin): +class LLM(LLMInterface[M, T], ReprMixin): config: openllm.LLMConfig """The config instance to use for this LLM. This will be created based on config_class and available when initialising the LLM.""" + quantization_config: transformers.BitsAndBytesConfig | None + """Quantisation config for quantised model on the fly.""" _model_id: str _runtime: t.Literal["ggml", "transformers"] @@ -401,55 +456,66 @@ class LLM(LLMInterface[_M, _T], ReprMixin): _tag: bentoml.Tag _adapters_mapping: AdaptersMapping _model_version: str + _quantize_method: t.Literal["int8", "int4", "gptq"] | None + + @staticmethod + def _infer_implementation_from_name(name: str) -> tuple[LiteralRuntime, str]: + if name.startswith("Flax"): + return "flax", name[4:] + elif name.startswith("TF"): + return "tf", name[2:] + else: + return "pt", name def __init_subclass__(cls): cd = cls.__dict__ - prefix_class_name_config = cls.__name__ - if prefix_class_name_config.startswith("Flax"): - implementation = "flax" - prefix_class_name_config = prefix_class_name_config[4:] - elif prefix_class_name_config.startswith("TF"): - implementation = "tf" - prefix_class_name_config = prefix_class_name_config[2:] - else: - implementation = "pt" + implementation, config_class_name = cls._infer_implementation_from_name(cls.__name__) cls.__llm_implementation__ = implementation - config_class = openllm.AutoConfig.infer_class_from_name(prefix_class_name_config) + config_class = openllm.AutoConfig.infer_class_from_name(config_class_name) if "__openllm_internal__" in cd: if "config_class" not in cd: cls.config_class = config_class - else: - logger.debug(f"Using config class {cd['config_class']} for {cls.__name__}.") - else: - if "config_class" not in cd: - raise RuntimeError( - "Missing required key 'config_class'. Make sure to define it within the LLM subclass." - ) + logger.debug("Using config class %s for %s.", cls.config_class, cls.__name__) + elif "config_class" not in cd: + raise RuntimeError("Missing required key 'config_class'. Make sure to define it within the LLM subclass.") - if cls.import_model is LLMInterface[_M, _T].import_model: + _custom_import = True + if cls.import_model is LLMInterface[M, T].import_model: # using the default import model if no custom import is set + _custom_import = False setattr(cls, "import_model", openllm.serialisation.import_model) + else: + import_func = getattr(cls, "import_model") - if cls.llm_post_init is LLMInterface[_M, _T].llm_post_init: + def _wrapped_import_model(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any): + # wrapped around custom init to provide some meta compression + # for all decls and attrs + (model_decls, model_attrs), _ = self.llm_parameters + + decls = (*model_decls, *decls) + attrs = {**model_attrs, **attrs} + + return import_func(self, *decls, trust_remote_code=trust_remote_code, **attrs) + + setattr(cls, "import_model", functools.update_wrapper(_wrapped_import_model, cls.import_model)) + + if cls.llm_post_init is LLMInterface[M, T].llm_post_init: # using the default post init if no custom post init is set wrapped_post_init = _default_post_init else: original_post_init = getattr(cls, "llm_post_init") - def wrapped_post_init(self: LLM[_M, _T]): + def wrapped_post_init(self: LLM[M, T]): _default_post_init(self) original_post_init(self) setattr(cls, "llm_post_init", wrapped_post_init) - cls.__llm_custom_load__ = None if cls.load_model is LLMInterface[_M, _T].load_model else cls.load_model - cls.__llm_custom_tokenizer__ = ( - None if cls.load_tokenizer is LLMInterface[_M, _T].load_tokenizer else cls.load_tokenizer - ) - cls.__llm_init_kwargs__ = ( - None if cls.import_kwargs is LLMInterface[_M, _T].import_kwargs else cls.import_kwargs - ) + cls.__llm_custom_import__ = _custom_import + cls.__llm_custom_load__ = False if cls.load_model is LLMInterface[M, T].load_model else True + cls.__llm_custom_tokenizer__ = False if cls.load_tokenizer is LLMInterface[M, T].load_tokenizer else True + cls.__llm_init_kwargs__ = None if cls.import_kwargs is LLMInterface[M, T].import_kwargs else cls.import_kwargs for at in {"bentomodel", "model", "tokenizer", "adapter_map"}: setattr(cls, f"__llm_{at}__", None) @@ -475,7 +541,10 @@ class LLM(LLMInterface[_M, _T], ReprMixin): if self.__llm_model__ is not None and self.bettertransformer: from optimum.bettertransformer import BetterTransformer - self.__llm_model__ = BetterTransformer.reverse(self.__llm_model__) + self.__llm_model__ = t.cast( + M, + BetterTransformer.reverse(t.cast("transformers.PreTrainedModel", self.__llm_model__)), # type: ignore + ) openllm.serialisation.save_pretrained(self, save_directory, **attrs) @@ -494,11 +563,12 @@ class LLM(LLMInterface[_M, _T], ReprMixin): adapter_map: dict[str, str | None] | None = None, quantization_config: transformers.BitsAndBytesConfig | None = None, **attrs: t.Any, - ) -> LLM[_M, _T]: + ) -> LLM[M, T]: """Instantiate a pretrained LLM. - it follows the same design principle as HuggingFace's `from_pretrained` method, plus the following: - Optimization options: + ``LLM.from_pretrained`` follows the same design principle as HuggingFace's `from_pretrained` method, plus the following: + + ### Optimization options: > This is most notable during serving time. @@ -507,7 +577,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin): > Currently, the above two options are mutually exclusive. - Adapter options: + ### Adapter options: > This is used in conjunction with the fine-tuning features @@ -528,6 +598,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin): will use `config_class` to construct default configuration. quantize: The quantization to use for this LLM. Defaults to None. Possible values include int8, int4 and gptq. + runtime: Optional runtime to run this LLM. Default to 'transformers'. 'ggml' supports is working in progress. quantization_config: The quantization config (`transformers.BitsAndBytesConfig`) to use. Note that this is mutually exclusive with `quantize` bettertransformer: Whether to use BetterTransformer with this model. Defaults to False. adapter_id: The [LoRA](https://arxiv.org/pdf/2106.09685.pdf) pretrained id or local path to use for this LLM. Defaults to None. @@ -550,26 +621,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin): """'quantization_config' and 'quantize' are mutually exclusive. Either customise your quantization_config or use the 'quantize' argument.""" ) - - # 8 bit configuration - int8_threshold = attrs.pop("llm_int8_threshhold", 6.0) - cpu_offloading = attrs.pop("llm_int8_enable_fp32_cpu_offload", False) - int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None) - int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False) - # 4 bit configuration - int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16) - int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4") - int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True) - - # NOTE: Quantization setup - # quantize is a openllm.LLM feature, where we can quantize the model - # with bitsandbytes or quantization aware training. if quantization_config is None and quantize is not None: - if not is_bitsandbytes_available(): - raise RuntimeError( - "Quantization requires bitsandbytes to be installed. Make " - "sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'" - ) logger.debug( "'quantize' is not None. %s will use a default 'quantization_config' for %s. " "If you want to customise the quantization config, make sure to pass your " @@ -577,48 +629,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin): cls.__name__, quantize, ) - if quantize == "int8": - if int8_skip_modules is None: - int8_skip_modules = [] - if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm": - logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__) - int8_skip_modules.append("lm_head") - quantization_config = transformers.BitsAndBytesConfig( - load_in_8bit=True, - llm_int8_enable_fp32_cpu_offload=cpu_offloading, - llm_int8_threshhold=int8_threshold, - llm_int8_skip_modules=int8_skip_modules, - llm_int8_has_fp16_weight=int8_has_fp16_weight, - ) - elif quantize == "int4": - if is_transformers_supports_kbit(): - quantization_config = transformers.BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=int4_compute_dtype, - bnb_4bit_quant_type=int4_quant_type, - bnb_4bit_use_double_quant=int4_use_double_quant, - ) - else: - logger.warning( - "'quantize' is set to int4, while the current transformers version %s does not support " - "k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore " - "make sure to install the latest version of transformers either via PyPI or " - "from git source: 'pip install git+https://github.com/huggingface/transformers'.", - pkg.pkg_version_info("transformers"), - ) - elif quantize == "gptq": - # TODO: support GPTQ loading quantization - raise NotImplementedError("GPTQ is not supported yet.") - if model_id is None: - raise RuntimeError( - "'quantize=%s' requires passing custom path to quantized weights as we are unable to load " - "the model on the fly. See https://github.com/qwopqwop200/GPTQ-for-LLaMa for " - "instruction on how to quantize '%s' with GPTQ.", - quantize, - cls.__name__, - ) - else: - raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantize} instead.") + quantization_config, attrs = infer_quantisation_config(cls, quantize, **attrs) # NOTE: Fine-tuning setup if adapter_map and adapter_id: @@ -636,49 +647,52 @@ class LLM(LLMInterface[_M, _T], ReprMixin): "LoRA adapter requires 'peft' to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'" ) - if llm_config is not None: - if DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3: - logger.debug("Using provided LLMConfig to initialize LLM instead of from default: %r", llm_config) - else: + if llm_config is None: llm_config = cls.config_class.model_construct_env(**attrs) # The rests of the kwargs that is not used by the config class should be stored into __openllm_extras__. attrs = llm_config["extras"] - try: - _tag = bentoml.Tag.from_taglike(model_id) - except ValueError: - _tag = make_tag( - model_id, - model_version=model_version, - trust_remote_code=cfg_cls.__openllm_trust_remote_code__, - implementation=cls.__llm_implementation__, - quiet=True, - ) - assert _tag.version is not None, "Failed to resolve model version." + _tag = cls._infer_tag_from_model_id(model_id, model_version) + if _tag.version is None: + raise RuntimeError("Failed to resolve model version.") return cls( + *args, model_id=model_id, llm_config=llm_config, - *args, bettertransformer=str(bettertransformer).upper() in ENV_VARS_TRUE_VALUES, quantization_config=quantization_config, - _tag=_tag, - _model_version=_tag.version, + _quantize_method=quantize, _adapters_mapping=resolve_peft_config_type(adapter_map), _runtime=runtime, + _model_version=_tag.version, + _tag=_tag, **attrs, ) + @classmethod + def _infer_tag_from_model_id(cls, model_id: str, model_version: str | None) -> bentoml.Tag: + try: + return bentoml.Tag.from_taglike(model_id) + except ValueError: + return make_tag( + model_id, + model_version=model_version, + trust_remote_code=cls.config_class.__openllm_trust_remote_code__, + implementation=cls.__llm_implementation__, + quiet=True, + ) + def __init__( self, + *args: t.Any, model_id: str, llm_config: openllm.LLMConfig, - *args: t.Any, - bettertransformer: bool | None = None, - quantization_config: transformers.BitsAndBytesConfig | None = None, - _adapters_mapping: dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] - | None = None, + bettertransformer: bool | None, + quantization_config: transformers.BitsAndBytesConfig | None, + _adapters_mapping: AdaptersMapping | None, _tag: bentoml.Tag, + _quantize_method: t.Literal["int8", "int4", "gptq"] | None, _runtime: t.Literal["ggml", "transformers"], _model_version: str, **attrs: t.Any, @@ -706,20 +720,20 @@ class LLM(LLMInterface[_M, _T], ReprMixin): ```python def import_model( self, - model_id: str, - tag: bentoml.Tag, *args: t.Any, - tokenizer_kwds: dict[str, t.Any], + trust_remote_code: bool, **attrs: t.Any, ): + _, tokenizer_attrs = self.llm_parameters + return bentoml.transformers.save_model( tag, transformers.AutoModelForCausalLM.from_pretrained( - model_id, device_map="auto", torch_dtype=torch.bfloat16, **attrs + self.model_id, device_map="auto", torch_dtype=torch.bfloat16, **attrs ), custom_objects={ "tokenizer": transformers.AutoTokenizer.from_pretrained( - model_id, padding_size="left", **tokenizer_kwds + self.model_id, padding_size="left", **tokenizer_attrs ) }, ) @@ -731,14 +745,14 @@ class LLM(LLMInterface[_M, _T], ReprMixin): ```python dolly_v2_runner = openllm.Runner( - "dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map="gpu" + "dolly-v2", _tokenizer_padding_size="left", torch_dtype=torch.bfloat16, device_map="cuda" ) ``` Note: If you implement your own `import_model`, then `import_kwargs` will be the - default kwargs for every load. You can still override those via ``openllm.Runner``. + base kwargs. You can still override those via ``openllm.Runner``. - Note that this tag will be generated based on `self.default_id` or the given `pretrained` kwds. + Note that this tag will be generated based on `self.default_id`. passed from the __init__ constructor. ``llm_post_init`` can also be implemented if you need to do any additional @@ -762,10 +776,10 @@ class LLM(LLMInterface[_M, _T], ReprMixin): llm_config: The config to use for this LLM. Defaults to None. If not passed, OpenLLM will use `config_class` to construct default configuration. bettertransformer: Whether to use BetterTransformer with this model. Defaults to False. + quantization_config: ``transformers.BitsAndBytesConfig`` configuration, or 'gptq' denoting this model to be loaded with GPTQ. *args: The args to be passed to the model. **attrs: The kwargs to be passed to the model. """ - # low_cpu_mem_usage is only available for model # this is helpful on system with low memory to avoid OOM low_cpu_mem_usage = attrs.pop("low_cpu_mem_usage", True) @@ -799,6 +813,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin): _tag, _adapters_mapping, _model_version, + _quantize_method, ) # handle trust_remote_code self.__llm_trust_remote_code__ = self._model_attrs.pop("trust_remote_code", self.config["trust_remote_code"]) @@ -863,19 +878,22 @@ class LLM(LLMInterface[_M, _T], ReprMixin): return normalise_model_name(self._model_id) @property - def identifying_params(self) -> dict[str, t.Any]: + def identifying_params(self) -> DictStrAny: return { "configuration": self.config.model_dump_json().decode(), "model_ids": orjson.dumps(self.config["model_ids"]).decode(), } @property - def llm_parameters(self) -> tuple[tuple[tuple[t.Any, ...], dict[str, t.Any]], dict[str, t.Any]]: - """Returning the processed model and tokenizer parameters to be used with - 'import_model' or any other place that requires loading model and tokenizer. + def llm_parameters(self) -> tuple[tuple[tuple[t.Any, ...], DictStrAny], DictStrAny]: + """Returning the processed model and tokenizer parameters. + + These can then be used with 'import_model' or any other place that requires loading model and tokenizer. See 'openllm.cli.download_models' for example usage. - It returns a tuple of (model_args, model_kwargs) & tokenizer_kwargs + + Returns: + ``tuple``: It returns a tuple of (model_args, model_kwargs) & tokenizer_kwargs """ return (self._model_decls, self._model_attrs), self._tokenizer_attrs @@ -885,35 +903,24 @@ class LLM(LLMInterface[_M, _T], ReprMixin): def ensure_model_id_exists(self) -> bentoml.Model: """This utility function will download the model if it doesn't exist yet. + Make sure to call this function if 'ensure_available' is not set during Auto LLM initialisation. + + The equivalent for ``openllm.Runner`` is ``openllm.Runner.download_model``. """ - - output = subprocess.check_output( - [ - sys.executable, - "-m", - "openllm", - "download", - self.config["start_name"], - self.model_id, - "--machine", - "--implementation", - self.__llm_implementation__, - "--model-version", - self._model_version, - "--runtime", - self.runtime, - ], - env=os.environ.copy(), + additional_args = ["--machine"] + if not get_debug_mode(): + additional_args.append("--quiet") + return openllm.import_model( + self.config["start_name"], + model_id=self.model_id, + model_version=self._model_version, + runtime=self.runtime, + implementation=self.__llm_implementation__, + quantize=self._quantize_method, + additional_args=additional_args, ) - # NOTE: This usually only concern OpenLLM devs. - pattern = r"^__tag__:[^:\n]+:[^:\n]+" - matched = re.search(pattern, output.decode("utf-8").strip(), re.MULTILINE) - assert matched is not None, f"Failed to find tag from output: {output}" - _, _, tag = matched.group(0).partition(":") - - return bentoml.models.get(tag) @property def _bentomodel(self) -> bentoml.Model: @@ -922,7 +929,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin): return self.__llm_bentomodel__ @property - def model(self) -> _M: + def model(self) -> M: """The model to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it.""" # Run check for GPU if self.config["requires_gpu"] and len(openllm.utils.gpu_count()) < 1: @@ -930,79 +937,72 @@ class LLM(LLMInterface[_M, _T], ReprMixin): if self.__llm_model__ is None: self.__llm_model__ = t.cast( - _M, openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs) + M, openllm.serialisation.load_model(self, *self._model_decls, **self._model_attrs) ) - return t.cast(_M, self.__llm_model__) + return t.cast(M, self.__llm_model__) @property - def tokenizer(self) -> _T: + def tokenizer(self) -> T: """The tokenizer to use for this LLM. This shouldn't be set at runtime, rather let OpenLLM handle it.""" if self.__llm_tokenizer__ is None: - self.__llm_tokenizer__ = t.cast(_T, openllm.serialisation.load_tokenizer(self)) + self.__llm_tokenizer__ = t.cast(T, openllm.serialisation.load_tokenizer(self)) return self.__llm_tokenizer__ + def _default_ft_config(self, _adapter_type: AdapterType, inference_mode: bool) -> FineTuneConfig: + strategy = first_not_none( + self.config["fine_tune_strategies"].get(_adapter_type), + default=FineTuneConfig(adapter_type=_adapter_type, llm_config_class=self.config_class), + ) + return strategy.eval() if inference_mode else strategy.train() + def _transpose_adapter_mapping( self, inference_mode: bool = True, use_cache: bool = True, - ) -> dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]]: - assert self._adapters_mapping is not None, "LoRA mapping is not set up correctly." + ) -> ResolvedAdaptersMapping: + if self._adapters_mapping is None: + raise ValueError("LoRA mapping is not set up correctly.") - if not use_cache: - logger.debug( - "'use_cache' is set to False. This means the adapter mapping resolution will not be cached. This should only be used during training." - ) - - if self.__llm_adapter_map__ is not None and use_cache: + if use_cache and self.__llm_adapter_map__ is not None: # early out if we already serialized everything. return self.__llm_adapter_map__ - adapter_map: dict[AdapterType, dict[str | t.Literal["default"], tuple[peft.PeftConfig, str]]] = {} + if not use_cache: + logger.debug("Adapter mapping resolution will not be cached. This should only be used during training.") + + adapter_map: ResolvedAdaptersMapping = {k: {} for k in self._adapters_mapping} # this is a temporary check to accept the first option name as 'default' # then we will raise Error when the optional_name is set to None in next iteration. _converted_first_none = False - for _adapter_type, _adapter_tuple in self._adapters_mapping.items(): - if _adapter_type not in adapter_map: - adapter_map[_adapter_type] = {} - default_config = self.config["fine_tune_strategies"].get( - _adapter_type, FineTuneConfig(adapter_type=_adapter_type, llm_config_class=self.config_class) - ) - default_config = default_config.eval() if inference_mode else default_config.train() - for pretrained_or_peft_id, optional_name, resolved_mapping in _adapter_tuple: - if not optional_name: - if not _converted_first_none: - _converted_first_none = True - optional_name = "default" - else: - raise ValueError( - f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {pretrained_or_peft_id, resolved_mapping}" - ) - assert isinstance(optional_name, str) # optional_name should all be resolved here - if optional_name == "default": - adapter_map[_adapter_type][optional_name] = ( - default_config.with_config(**resolved_mapping).to_peft_config(), - pretrained_or_peft_id, - ) - else: - adapter_map[_adapter_type][optional_name] = ( - FineTuneConfig( - adapter_type=_adapter_type, - adapter_config=resolved_mapping, - inference_mode=inference_mode, - llm_config_class=self.config_class, - ).to_peft_config(), - pretrained_or_peft_id, + for _adapter_type, _adapters_tuples in self._adapters_mapping.items(): + default_config = self._default_ft_config(_adapter_type, inference_mode) + for adapter in _adapters_tuples: + if not adapter.name and _converted_first_none: + raise ValueError( + f"{self.__class__.__name__} doesn't know how to resolve adapter_name None mapping: {adapter.adapter_id, adapter.config}" ) + name = adapter.name + if name is None: + _converted_first_none = True + name = "default" + peft_config = ( + default_config.with_config(**adapter.config).to_peft_config() + if name == "default" + else FineTuneConfig( + adapter_type=_adapter_type, + adapter_config=adapter.config, + inference_mode=inference_mode, + llm_config_class=self.config_class, + ).to_peft_config() + ) + adapter_map[_adapter_type][name] = (peft_config, adapter.adapter_id) if self.__llm_adapter_map__ is None and use_cache: self.__llm_adapter_map__ = adapter_map - - return self.__llm_adapter_map__ - return adapter_map @requires_dependencies("peft", extra="fine-tune") - def prepare_for_training(self, adapter_type: AdapterType = "lora", **attrs: t.Any) -> tuple[_M, _T]: + def prepare_for_training(self, adapter_type: AdapterType = "lora", **attrs: t.Any) -> tuple[peft.PeftModel, T]: if pkg.pkg_version_info("peft")[:2] >= (0, 4): from peft import prepare_model_for_kbit_training else: @@ -1027,18 +1027,20 @@ class LLM(LLMInterface[_M, _T], ReprMixin): adapter_type: AdapterType = "lora", load_adapters: t.Literal["all"] | list[str] | None = None, use_cache: bool = True, - ) -> peft.PeftModel | _M | torch.nn.Module: - """Apply given LoRA mapping to the model. Note that the base model can still - be accessed via self.model.get_base_model(). - """ - assert self.model, "Internal error: Model is not loaded correctly." - assert self.__llm_model__ is not None + ) -> peft.PeftModel | M: + """Apply given LoRA mapping to the model. - # early out if _adapters_mapping is empty or it is already wrapped - # with peft. + Note that the base model can still be accessed via self.model.get_base_model(). + """ + assert self.__llm_model__ is not None # noqa: S101 + + # early out if _adapters_mapping is empty or it is already wrapped with peft. if not self._adapters_mapping: logger.debug("No adapter mapping is found. Skip applying adapter.") return self.__llm_model__ + if isinstance(self.__llm_model__, peft.PeftModel): + logger.debug("Model is already wrapped with peft. Skip applying adapter.") + return self.__llm_model__ _mapping = self._transpose_adapter_mapping(inference_mode=inference_mode, use_cache=use_cache) if adapter_type not in _mapping: @@ -1046,40 +1048,17 @@ class LLM(LLMInterface[_M, _T], ReprMixin): f"Given adapter type {adapter_type} is not supported. Please choose from {list(_mapping.keys())}" ) adapter_mapping = _mapping[adapter_type] - default_config, peft_model_id = adapter_mapping.pop("default", None) - if default_config is None: - raise ValueError( - "There is no 'default' mapping. Please check the adapter mapping and report this bug to the OpenLLM team." - ) - # the below shared similar logics with `get_peft_model` - # TODO: Support PromptLearningConfig - if default_config.task_type not in peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( - default_config, peft.PromptLearningConfig - ): - logger.debug( - "Given task type '%s' is not supported by peft. This means it can be a custom PeftModel implementation. Make sure the adapter is loaded manually before running inference.", - default_config.task_type, - ) - self.__llm_model__ = peft.PeftModel(self.__llm_model__, default_config) - else: - # this is not ideal to serialize like this, wait until https://github.com/huggingface/peft/pull/612 - # is merged - peft_class = peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING[default_config.task_type] - if t.cast("str | None", default_config.base_model_name_or_path) is not None: - kwargs: dict[str, t.Any] = {"is_trainable": not inference_mode} - if "config" in inspect.signature(peft_class.from_pretrained).parameters: - kwargs["config"] = default_config - else: - kwargs.update(dict(default_config.to_dict().items())) - self.__llm_model__ = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs) - else: - # in this case, the given base_model_name_or_path is None. This will be hit during training - self.__llm_model__ = peft_class(self.__llm_model__, default_config) + self.__llm_model__ = self._wrap_default_peft_model(adapter_mapping, inference_mode=inference_mode) + + if not isinstance(self.__llm_model__, peft.PeftModel): + # We hit this branch during inference + # TODO: load multiple adapters + return self.__llm_model__ # now we loop through the rest with add_adapter if len(adapter_mapping) > 0: - for adapter_name, _peft_config in adapter_mapping.items(): + for adapter_name, (_peft_config, _) in adapter_mapping.items(): self.__llm_model__.add_adapter(adapter_name, _peft_config) # optionally load adapters. In case of multiple adapters, or on Runner, @@ -1097,14 +1076,53 @@ class LLM(LLMInterface[_M, _T], ReprMixin): return self.__llm_model__ + def _wrap_default_peft_model(self, adapter_mapping: dict[str, tuple[peft.PeftConfig, str]], inference_mode: bool): + assert self.__llm_model__ is not None, "Error: Model is not loaded correctly" # noqa: S101 + if isinstance(self.__llm_model__, peft.PeftModel): + logger.warning("Model is already wrapped with peft. Skip wrapping with default peft model.") + return self.__llm_model__ + + if "default" not in adapter_mapping: + raise ValueError( + "There is no 'default' mapping. Please check the adapter mapping and report this bug to the OpenLLM team." + ) + + default_config, peft_model_id = adapter_mapping.pop("default") + + # the below shared similar logics with `get_peft_model` + # TODO: Support PromptLearningConfig + if default_config.task_type not in peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( + default_config, peft.PromptLearningConfig + ): + logger.debug( + "Given task type '%s' is not supported by peft. Make sure the adapter is loaded manually before running inference.", + default_config.task_type, + ) + model = peft.PeftModel(self.__llm_model__, default_config) + else: + # XXX: this is not ideal to serialize like this, maybe for fine-tune we will only support 0.4.0 + # onwards. For now, keep this logic here. + peft_class = peft.MODEL_TYPE_TO_PEFT_MODEL_MAPPING[default_config.task_type] + if default_config.base_model_name_or_path: + kwargs: DictStrAny = {"is_trainable": not inference_mode} + if "config" in inspect.signature(peft_class.from_pretrained).parameters: + kwargs["config"] = default_config + else: + kwargs.update(dict(default_config.to_dict().items())) + # BUG: This hits during inference, need fixing + model = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs) + else: + # in this case, the given base_model_name_or_path is None. This will be hit during training + model = peft_class(self.__llm_model__, default_config) + return model + # order of these fields matter here, make sure to sync it with - # openllm.models.auto.factory._BaseAutoLLMClass.for_model + # openllm.models.auto.factory.BaseAutoLLMClass.for_model def to_runner( self, models: list[bentoml.Model] | None = None, max_batch_size: int | None = None, max_latency_ms: int | None = None, - method_configs: dict[str, ModelSignatureDict | ModelSignature] | None = None, scheduling_strategy: type[Strategy] | None = None, ) -> LLMRunner: """Convert this LLM into a Runner. @@ -1114,7 +1132,6 @@ class LLM(LLMInterface[_M, _T], ReprMixin): By default, this will be determined from the model_name. max_batch_size: The maximum batch size for the runner. max_latency_ms: The maximum latency for the runner. - method_configs: The method configs for the runner. strategy: The strategy to use for this runner. embedded: Whether to run this runner in embedded mode. scheduling_strategy: Whether to create a custom scheduling strategy for this Runner. @@ -1126,6 +1143,7 @@ class LLM(LLMInterface[_M, _T], ReprMixin): - 'name': will be generated by OpenLLM, hence users don't shouldn't worry about this. The generated name will be 'llm--runner' (ex: llm-dolly-v2-runner, llm-chatglm-runner) - 'embedded': Will be disabled by default. There is no reason to run LLM in embedded mode. + - 'method_configs': The method configs for the runner will be managed internally by OpenLLM. """ models = models if models is not None else [] @@ -1138,192 +1156,30 @@ class LLM(LLMInterface[_M, _T], ReprMixin): from ._strategies import CascadingResourceStrategy scheduling_strategy = CascadingResourceStrategy - else: - logger.debug("Using custom scheduling strategy: %s", scheduling_strategy) generate_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False)) generate_iterator_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=True)) - if method_configs is None: - method_configs = { - "generate": generate_sig, - "generate_one": generate_sig, - "generate_iterator": generate_iterator_sig, - } - else: - signatures = ModelSignature.convert_signatures_dict(method_configs) - generate_sig = first_not_none(signatures.get("generate"), default=generate_sig) - generate_iterator_sig = first_not_none(signatures.get("generate_iterator"), default=generate_iterator_sig) - - class _Runnable(bentoml.Runnable): - SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu", "cpu") - SUPPORTS_CPU_MULTI_THREADING = True - - def __init__(__self: _Runnable): - # NOTE: The side effect of this line - # is that it will load the imported model during - # runner startup. So don't remove it!! - __self.model = self.model # keep a loaded reference - if self.adapters_mapping is not None: - logger.info("Applying LoRA to %s...", self.runner_name) - self.apply_adapter(inference_mode=True, load_adapters="all") - - @bentoml.Runnable.method(batchable=False) - def set_adapter(__self, adapter_name: str) -> dict[t.Literal["success", "error_msg"], bool | str]: - if not is_peft_available(): - return { - "success": False, - "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'", - } - if self.__llm_adapter_map__ is None: - return { - "success": False, - "error_msg": "No adapters available for current running server.", - } - if not isinstance(self.model, peft.PeftModel): - return {"success": False, "error_msg": "Model is not a PeftModel"} - try: - self.model.set_adapter(adapter_name) - return {"success": True, "error_msg": ""} - except ValueError: - logger.info("Adapter %s not found", adapter_name) - return { - "success": False, - "error_msg": f"Adapter {adapter_name} not found. Available adapters: {list(self.model.peft_config)}", - } - - @bentoml.Runnable.method( - batchable=generate_sig.batchable, - batch_dim=generate_sig.batch_dim, - input_spec=generate_sig.input_spec, - output_spec=generate_sig.output_spec, - ) - def __call__(__self, prompt: str, **attrs: t.Any) -> list[t.Any]: - return self.generate(prompt, **attrs) - - @bentoml.Runnable.method( - batchable=generate_sig.batchable, - batch_dim=generate_sig.batch_dim, - input_spec=generate_sig.input_spec, - output_spec=generate_sig.output_spec, - ) - def generate(__self, prompt: str, **attrs: t.Any) -> list[t.Any]: - return self.generate(prompt, **attrs) - - @bentoml.Runnable.method( - batchable=generate_sig.batchable, - batch_dim=generate_sig.batch_dim, - input_spec=generate_sig.input_spec, - output_spec=generate_sig.output_spec, - ) - def generate_one( - __self, prompt: str, stop: list[str], **attrs: t.Any - ) -> list[dict[t.Literal["generated_text"], str]]: - return self.generate_one(prompt, stop, **attrs) - - @bentoml.Runnable.method( - batchable=generate_iterator_sig.batchable, - batch_dim=generate_iterator_sig.batch_dim, - input_spec=generate_iterator_sig.input_spec, - output_spec=generate_iterator_sig.output_spec, - ) - def generate_iterator(__self, prompt: str, **attrs: t.Any) -> t.Generator[t.Any, None, None]: - yield self.generate_iterator(prompt, **attrs) - - def available_adapters(__self: LLMRunner) -> PeftAdapterOutput: - if not is_peft_available(): - return { - "success": False, - "result": {}, - "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'", - } - if self.__llm_adapter_map__ is None: - return { - "success": False, - "result": {}, - "error_msg": "No adapters available for current running server.", - } - if not isinstance(__self.model, peft.PeftModel): - return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"} - return {"success": True, "result": __self.model.peft_config, "error_msg": ""} - - def _wrapped_generate_run(__self: LLMRunner, prompt: str, **kwargs: t.Any) -> t.Any: - """Wrapper for runner.generate.run() to handle the prompt and postprocessing. - - This will be used for LangChain API. - - Usage: - ```python - runner = openllm.Runner("dolly-v2", init_local=True) - runner("What is the meaning of life?") - ``` - """ - prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **kwargs) - generated_result = __self.generate.run(prompt, **generate_kwargs) - return self.postprocess_generate(prompt, generated_result, **postprocess_kwargs) - - def _wrapped_repr_keys(_: LLMRunner) -> set[str]: - return {"config", "llm_type", "runner_methods", "runtime", "llm_tag"} - - def _wrapped_repr_args(__self: LLMRunner) -> ReprArgs: - yield "runner_methods", { - method.name: { - "batchable": method.config.batchable, - "batch_dim": method.config.batch_dim if method.config.batchable else None, - } - for method in __self.runner_methods - } - yield "config", self.config.model_dump(flatten=True) - yield "llm_type", __self.llm_type - yield "runtime", self.runtime - yield "llm_tag", self.tag # NOTE: returning the two langchain API's to the runner - return types.new_class( - inflection.camelize(self.config["model_name"]) + "Runner", - (bentoml.Runner,), - exec_body=lambda ns: ns.update( - { - "llm_type": self.llm_type, - "identifying_params": self.identifying_params, - "llm_tag": self.tag, - "llm": self, # NOTE: self reference to LLM - "config": self.config, - "peft_adapters": property(fget=available_adapters), - "download_model": self.ensure_model_id_exists, - "__call__": _wrapped_generate_run, - "__module__": self.__module__, - "__doc__": self.config["env"].start_docstring, - "__repr__": ReprMixin.__repr__, - "__repr_keys__": property(_wrapped_repr_keys), - "__repr_args__": _wrapped_repr_args, - } - ), - )( - types.new_class( - inflection.camelize(self.config["model_name"]) + "Runnable", - (_Runnable,), - {}, - lambda ns: ns.update( - { - "SUPPORTED_RESOURCES": ("nvidia.com/gpu", "amd.com/gpu") - if self.config["requires_gpu"] - else ("nvidia.com/gpu", "amd.com/gpu", "cpu"), - "__module__": self.__module__, - "__doc__": self.config["env"].start_docstring, - } - ), - ), + return llm_runner_class(self)( + llm_runnable_class(self, generate_sig, generate_iterator_sig), name=self.runner_name, embedded=False, models=models, max_batch_size=max_batch_size, max_latency_ms=max_latency_ms, - method_configs=bentoml_cattr.unstructure(method_configs), + method_configs=bentoml_cattr.unstructure( + { + "generate": generate_sig, + "generate_one": generate_sig, + "generate_iterator": generate_iterator_sig, + } + ), scheduling_strategy=scheduling_strategy, ) def predict(self, prompt: str, **attrs: t.Any) -> t.Any: - """The scikit-compatible API for self(...)""" + """The scikit-compatible API for self(...).""" return self.__call__(prompt, **attrs) def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: @@ -1376,11 +1232,11 @@ def Runner( model_name: str, ensure_available: bool | None = None, init_local: bool = False, - implementation: t.Literal["pt", "flax", "tf"] | None = None, + implementation: LiteralRuntime | None = None, llm_config: openllm.LLMConfig | None = None, **attrs: t.Any, ) -> LLMRunner: - """Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models' + """Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'. The behaviour of ensure_available that is synonymous to `AutoLLM.for_model` depends on `init_local`. By default, `ensure_available` is synonymous to `init_local`, meaning on the service when creating @@ -1406,6 +1262,7 @@ def Runner( If False, make sure the model is available locally. implementation: The given Runner implementation one choose for this Runner. By default, it is retrieved from the enviroment variable of the respected model_name. For example: 'flan-t5' -> "OPENLLM_FLAN_T5_FRAMEWORK" + llm_config: Optional ``openllm.LLMConfig`` to initialise this ``openllm.LLMRunner``. init_local: If True, it will initialize the model locally. This is useful if you want to run the model locally. (Symmetrical to bentoml.Runner.init_local()) **attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs @@ -1423,10 +1280,7 @@ def Runner( if implementation is None: implementation = EnvVarMixin(model_name)["framework_value"] - runner = t.cast( - "_BaseAutoLLMClass", - openllm[implementation], # type: ignore (internal API) - ).create_runner( + runner = openllm.infer_auto_class(implementation).create_runner( model_name, llm_config=llm_config, ensure_available=ensure_available if ensure_available is not None else init_local, @@ -1437,3 +1291,159 @@ def Runner( runner.init_local(quiet=True) return runner + + +def method_signature(sig: ModelSignature) -> ModelSignatureDict: + return bentoml_cattr.unstructure(sig) + + +class SetAdapterOutput(t.TypedDict): + success: bool + message: str + + +def llm_runnable_class( + self: openllm.LLM[M, T], + generate_sig: ModelSignature, + generate_iterator_sig: ModelSignature, +) -> type[LLMRunnable]: + class _Runnable(bentoml.Runnable): + SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu", "cpu") + SUPPORTS_CPU_MULTI_THREADING = True + + def __init__(__self: _Runnable): + # NOTE: The side effect of this line + # is that it will load the imported model during + # runner startup. So don't remove it!! + __self.model = self.model # keep a loaded reference + if self.adapters_mapping is not None: + logger.info("Applying LoRA to %s...", self.runner_name) + self.apply_adapter(inference_mode=True, load_adapters="all") + + @bentoml.Runnable.method(batchable=False) + def set_adapter(__self, adapter_name: str) -> SetAdapterOutput: + success = False + message = None + if not is_peft_available(): + message = "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'" + elif self.__llm_adapter_map__ is None: + message = "No adapters available for current running server." + elif not isinstance(__self.model, peft.PeftModel): + message = "Model is not a PeftModel" + if message is not None: + return SetAdapterOutput(success=success, message=message) + + try: + t.cast("peft.PeftModel", __self.model).set_adapter(adapter_name) + return SetAdapterOutput(success=True, message=f"Successfully set current adapter to {adapter_name}") + except ValueError: + logger.info("Adapter %s not found", adapter_name) + return SetAdapterOutput( + success=success, + message=f"Adapter {adapter_name} not found. Available adapters: {list(t.cast('peft.PeftModel', __self.model).peft_config)}", + ) + + @bentoml.Runnable.method(**method_signature(generate_sig)) + def __call__(__self, prompt: str, **attrs: t.Any) -> list[t.Any]: + return self.generate(prompt, **attrs) + + @bentoml.Runnable.method(**method_signature(generate_sig)) + def generate(__self, prompt: str, **attrs: t.Any) -> list[t.Any]: + return self.generate(prompt, **attrs) + + @bentoml.Runnable.method(**method_signature(generate_sig)) + def generate_one( + __self, prompt: str, stop: list[str], **attrs: t.Any + ) -> list[dict[t.Literal["generated_text"], str]]: + return self.generate_one(prompt, stop, **attrs) + + @bentoml.Runnable.method(**method_signature(generate_iterator_sig)) + def generate_iterator(__self, prompt: str, **attrs: t.Any) -> t.Generator[t.Any, None, None]: + yield self.generate_iterator(prompt, **attrs) + + return types.new_class( + self.__class__.__name__ + "Runnable", + (_Runnable,), + {}, + lambda ns: ns.update( + { + "SUPPORTED_RESOURCES": ("nvidia.com/gpu", "amd.com/gpu") + if self.config["requires_gpu"] + else ("nvidia.com/gpu", "amd.com/gpu", "cpu"), + "__module__": self.__module__, + "__doc__": self.config["env"].start_docstring, + } + ), + ) + + +def llm_runner_class(self: openllm.LLM[M, T]) -> type[LLMRunner]: + def available_adapters(__self: LLMRunner) -> PeftAdapterOutput: + if not is_peft_available(): + return { + "success": False, + "result": {}, + "error_msg": "peft is not available. Make sure to install: 'pip install \"openllm[fine-tune]\"'", + } + if self.__llm_adapter_map__ is None: + return { + "success": False, + "result": {}, + "error_msg": "No adapters available for current running server.", + } + if not isinstance(__self.model, peft.PeftModel): + return {"success": False, "result": {}, "error_msg": "Model is not a PeftModel"} + return {"success": True, "result": __self.model.peft_config, "error_msg": ""} + + def _wrapped_generate_run(__self: LLMRunner, prompt: str, **kwargs: t.Any) -> t.Any: + """Wrapper for runner.generate.run() to handle the prompt and postprocessing. + + This will be used for LangChain API. + + Usage: + ```python + runner = openllm.Runner("dolly-v2", init_local=True) + runner("What is the meaning of life?") + ``` + """ + prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **kwargs) + generated_result = __self.generate.run(prompt, **generate_kwargs) + return self.postprocess_generate(prompt, generated_result, **postprocess_kwargs) + + def _wrapped_repr_keys(_: LLMRunner) -> set[str]: + return {"config", "llm_type", "runner_methods", "runtime", "llm_tag"} + + def _wrapped_repr_args(__self: LLMRunner) -> ReprArgs: + yield "runner_methods", { + method.name: { + "batchable": method.config.batchable, + "batch_dim": method.config.batch_dim if method.config.batchable else None, + } + for method in __self.runner_methods + } + yield "config", self.config.model_dump(flatten=True) + yield "llm_type", __self.llm_type + yield "runtime", self.runtime + yield "llm_tag", self.tag + + return types.new_class( + self.__class__.__name__ + "Runner", + (bentoml.Runner,), + exec_body=lambda ns: ns.update( + { + "llm_type": self.llm_type, + "identifying_params": self.identifying_params, + "llm_tag": self.tag, + "llm": self, # NOTE: self reference to LLM + "config": self.config, + "peft_adapters": property(fget=available_adapters), + "download_model": self.ensure_model_id_exists, + "__call__": _wrapped_generate_run, + "__module__": self.__module__, + "__doc__": self.config["env"].start_docstring, + "__repr__": ReprMixin.__repr__, + "__repr_keys__": property(_wrapped_repr_keys), + "__repr_args__": _wrapped_repr_args, + } + ), + ) diff --git a/src/openllm/_package.py b/src/openllm/_package.py index b6be7996..eef14e0f 100644 --- a/src/openllm/_package.py +++ b/src/openllm/_package.py @@ -11,17 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Any build-related utilities. This is used for CI. +"""Build-related utilities. Some of these utilities are mainly used for 'openllm.build'. + +These utilities will stay internal, and its API can be changed or updated without backward-compatibility. """ from __future__ import annotations - import importlib.metadata import logging import os -import re -import subprocess -import sys import typing as t from pathlib import Path @@ -34,7 +31,6 @@ from simple_di import Provide from simple_di import inject import bentoml -import openllm from bentoml._internal.bento.build_config import BentoBuildConfig from bentoml._internal.bento.build_config import DockerOptions from bentoml._internal.bento.build_config import PythonOptions @@ -56,7 +52,7 @@ from .utils import resolve_user_filepath if t.TYPE_CHECKING: from fs.base import FS - from bentoml._internal.bento import BentoStore + import openllm logger = logging.getLogger(__name__) @@ -135,7 +131,8 @@ def construct_python_options( env: EnvVarMixin = llm.config["env"] framework_envvar = env["framework_value"] if framework_envvar == "flax": - assert is_flax_available(), f"Flax is not available, while {env.framework} is set to 'flax'" + if not is_flax_available(): + raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'") packages.extend( [ handle_package_version("flax", has_dockerfile_template), @@ -144,7 +141,8 @@ def construct_python_options( ] ) elif framework_envvar == "tf": - assert is_tf_available(), f"TensorFlow is not available, while {env.framework} is set to 'tf'" + if not is_tf_available(): + raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'") candidates = ( "tensorflow", "tensorflow-cpu", @@ -170,7 +168,8 @@ def construct_python_options( except importlib.metadata.PackageNotFoundError: pass else: - assert is_torch_available(), "PyTorch is not available. Make sure to have it locally installed." + if not is_torch_available(): + raise ValueError("PyTorch is not available. Make sure to have it locally installed.") packages.extend([handle_package_version("torch", has_dockerfile_template)]) wheels: list[str] = [] @@ -206,7 +205,7 @@ def construct_docker_options( "OPENLLM_MODEL": llm.config["model_name"], "OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'", "BENTOML_DEBUG": str(get_debug_mode()), - "BENTOML_CONFIG_OPTIONS": _bentoml_config_options, + "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", } if adapter_map: @@ -257,7 +256,8 @@ def create_bento( logger.info("Building Bento for '%s'", llm.config["start_name"]) if adapter_map is not None: - assert build_ctx is not None, "build_ctx is required when 'adapter_map' is not None" + if build_ctx is None: + raise ValueError("build_ctx is required when 'adapter_map' is not None") updated_mapping: dict[str, str | None] = {} for adapter_id, name in adapter_map.items(): try: @@ -321,7 +321,7 @@ def create_bento( # new behaviour with BentoML models model = _model_store.get(f"{model_framework}-{model_type}") except bentoml.exceptions.NotFound: - raise OpenLLMException(f"Failed to find models for {llm.config['start_name']}") + raise OpenLLMException(f"Failed to find models for {llm.config['start_name']}") from None # NOTE: the model_id_path here are only used for setting this environment variable within the container # built with for BentoLLM. @@ -330,10 +330,12 @@ def create_bento( with open(service_path, "r") as f: service_contents = f.readlines() + rel_path = f"../models/{model.tag.path()}" + for it in service_contents: if codegen.OPENLLM_MODEL_ID in it: service_contents[service_contents.index(it)] = ( - codegen.ModelIdFormatter(str(model.tag)).vformat(it)[: -(len(codegen.OPENLLM_MODEL_ID) + 3)] + "\n" + codegen.ModelIdFormatter(rel_path).vformat(it)[: -(len(codegen.OPENLLM_MODEL_ID) + 3)] + "\n" ) if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag)) @@ -346,70 +348,3 @@ def create_bento( bento._fs.writetext(service_fs_path, script) return bento.save() - - -@inject -def build( - model_name: str, - *, - model_id: str | None = None, - model_version: str | None = None, - quantize: t.Literal["int8", "int4", "gptq"] | None = None, - bettertransformer: bool | None = None, - adapter_map: dict[str, str | None] | None = None, - build_ctx: str | None = None, - extra_dependencies: tuple[str, ...] | None = None, - workers_per_resource: int | float | None = None, - overwrite_existing_bento: bool = False, - runtime: t.Literal["ggml", "transformers"] = "transformers", - dockerfile_template: str | None = None, - bento_store: BentoStore = Provide[BentoMLContainer.bento_store], -) -> bentoml.Bento: - """Package a LLM into a Bento. - - The LLM will be built into a BentoService with the following structure: - if quantize is passed, it will instruct the model to be quantized dynamically during serving time. - if bettertransformer is passed, it will instruct the model to use BetterTransformer during serving time. - - Other parameters including model_name, model_id and attrs will be passed to the LLM class itself. - """ - args = [sys.executable, "-m", "openllm", "build", model_name, "--machine", "--runtime", runtime] - - if quantize and bettertransformer: - raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") - - if quantize: - args.extend(["--quantize", quantize]) - if bettertransformer: - args.append("--bettertransformer") - - if model_id: - args.extend(["--model-id", model_id]) - if build_ctx: - args.extend(["--build-ctx", build_ctx]) - if extra_dependencies: - args.extend([f"--enable-features={f}" for f in extra_dependencies]) - if workers_per_resource: - args.extend(["--workers-per-resource", str(workers_per_resource)]) - if overwrite_existing_bento: - args.append("--overwrite") - if adapter_map: - args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) - if model_version: - args.extend(["--model-version", model_version]) - if dockerfile_template: - args.extend(["--dockerfile-template", dockerfile_template]) - - try: - output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd()) - except subprocess.CalledProcessError as e: - logger.error("Exception caught while building %s", model_name, exc_info=e) - if e.stderr: - raise OpenLLMException(e.stderr.decode("utf-8")) from None - raise OpenLLMException(str(e)) from None - # NOTE: This usually only concern BentoML devs. - pattern = r"^__tag__:[^:\n]+:[^:\n]+" - matched = re.search(pattern, output.decode("utf-8").strip(), re.MULTILINE) - assert matched is not None, f"Failed to find tag from output: {output}" - _, _, tag = matched.group(0).partition(":") - return bentoml.get(tag, _bento_store=bento_store) diff --git a/src/openllm/_prompt.py b/src/openllm/_prompt.py index d96bc62f..708980bf 100644 --- a/src/openllm/_prompt.py +++ b/src/openllm/_prompt.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import string import typing as t @@ -20,10 +19,10 @@ import typing as t class PromptFormatter(string.Formatter): """This PromptFormatter is largely based on langchain's implementation.""" - def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> str: + def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.LiteralString: if len(args) > 0: raise ValueError("Positional arguments are not supported") - return super().vformat(format_string, args, kwargs) + return t.cast("t.LiteralString", super().vformat(format_string, args, kwargs)) def check_unused_args( self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] diff --git a/src/openllm/_quantisation.py b/src/openllm/_quantisation.py index 3a2faba5..31c17de5 100644 --- a/src/openllm/_quantisation.py +++ b/src/openllm/_quantisation.py @@ -11,3 +11,92 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations +import logging +import typing as t + +from .utils import LazyLoader +from .utils import is_bitsandbytes_available +from .utils import is_transformers_supports_kbit +from .utils import pkg + + +if t.TYPE_CHECKING: + import torch + + import openllm + import transformers + + from ._types import DictStrAny +else: + torch = LazyLoader("torch", globals(), "torch") + transformers = LazyLoader("transformers", globals(), "transformers") + +logger = logging.getLogger(__name__) + +QuantiseMode = t.Literal["int8", "int4", "gptq"] + + +def infer_quantisation_config( + cls: type[openllm.LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any +) -> tuple[transformers.BitsAndBytesConfig | t.Any, DictStrAny]: + # 8 bit configuration + int8_threshold = attrs.pop("llm_int8_threshhold", 6.0) + int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False) + int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None) + int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False) + + def create_int8_config(int8_skip_modules: list[str] | None): + if int8_skip_modules is None: + int8_skip_modules = [] + if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm": + logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__) + int8_skip_modules.append("lm_head") + return transformers.BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload, + llm_int8_threshhold=int8_threshold, + llm_int8_skip_modules=int8_skip_modules, + llm_int8_has_fp16_weight=int8_has_fp16_weight, + ) + + # 4 bit configuration + int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16) + int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4") + int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True) + + # NOTE: Quantization setup + # quantize is a openllm.LLM feature, where we can quantize the model + # with bitsandbytes or quantization aware training. + if not is_bitsandbytes_available(): + raise RuntimeError( + "Quantization requires bitsandbytes to be installed. Make " + "sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'" + ) + if quantise == "int8": + quantisation_config = create_int8_config(int8_skip_modules) + elif quantise == "int4": + if is_transformers_supports_kbit(): + quantisation_config = transformers.BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=int4_compute_dtype, + bnb_4bit_quant_type=int4_quant_type, + bnb_4bit_use_double_quant=int4_use_double_quant, + ) + else: + logger.warning( + "'quantize' is set to int4, while the current transformers version %s does not support " + "k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore " + "make sure to install the latest version of transformers either via PyPI or " + "from git source: 'pip install git+https://github.com/huggingface/transformers'.", + pkg.pkg_version_info("transformers"), + ) + logger.warning("OpenLLM will fallback to 8-bit quantization.") + quantisation_config = create_int8_config(int8_skip_modules) + elif quantise == "gptq": + # TODO: support GPTQ loading quantization + raise NotImplementedError("GPTQ is not supported yet.") + else: + raise ValueError(f"'quantize' must be one of ['int8', 'int4', 'gptq'], got {quantise} instead.") + + return quantisation_config, attrs diff --git a/src/openllm/_schema.py b/src/openllm/_schema.py index e91bedd6..c8631cd0 100644 --- a/src/openllm/_schema.py +++ b/src/openllm/_schema.py @@ -11,11 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Schema definition for OpenLLM. This can be use for client interaction. -""" +"""Schema definition for OpenLLM. This can be use for client interaction.""" from __future__ import annotations - import functools import typing as t @@ -23,6 +20,8 @@ import attr import inflection import openllm +from openllm._configuration import GenerationConfig +from openllm.utils import bentoml_cattr if t.TYPE_CHECKING: @@ -79,6 +78,14 @@ class GenerationOutput: configuration: t.Dict[str, t.Any] """A mapping of configuration values for given system.""" + @property + def marshaled_config(self) -> GenerationConfig: + return bentoml_cattr.structure(self.configuration, GenerationConfig) + + @property + def unmarshaled(self) -> dict[str, t.Any]: + return bentoml_cattr.unstructure(self) + @attr.frozen(slots=True) class MetadataOutput: diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 433bcdcc..12b3d1a5 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -The service definition for running any LLMService. +"""The service definition for running any LLMService. Note that the line `model = ...` is a special line and should not be modified. This will be handled by openllm internally to generate the correct model service when bundling the LLM to a Bento. @@ -22,7 +21,6 @@ This will ensure that 'bentoml serve llm-bento' will work accordingly. The generation code lives under utils/codegen.py """ from __future__ import annotations - import os import typing as t import warnings @@ -136,7 +134,7 @@ async def hf_agent(request: Request) -> Response: except orjson.JSONDecodeError as err: raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None - stop = input_data.parameters.pop("stop", "\n") + stop = input_data.parameters.pop("stop", ["\n"]) try: resp = await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters) return JSONResponse(resp, status_code=200) @@ -150,9 +148,11 @@ svc.mount_asgi_app(hf_app, path="/hf") async def list_adapter_v1(_: Request) -> Response: - res = runner.peft_adapters - if res["success"]: - res["result"] = {k: v.to_dict() for k, v in res["result"].items()} + res: dict[str, t.Any] = {} + if runner.peft_adapters["success"] is True: + res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()} + res["success"] = runner.peft_adapters["success"] + res["error_msg"] = runner.peft_adapters["error_msg"] return JSONResponse(res, status_code=200) diff --git a/src/openllm/_strategies.py b/src/openllm/_strategies.py index 2dc4e3db..9ec9e270 100644 --- a/src/openllm/_strategies.py +++ b/src/openllm/_strategies.py @@ -13,8 +13,6 @@ # limitations under the License. from __future__ import annotations - -import functools import logging import math import os @@ -23,19 +21,19 @@ import typing as t import psutil -import bentoml -import openllm from bentoml._internal.resource import Resource from bentoml._internal.resource import get_resource from bentoml._internal.resource import system_resources from bentoml._internal.runner.strategy import THREAD_ENVS from bentoml._internal.runner.strategy import Strategy -from .utils import LazyType +from .exceptions import OpenLLMException from .utils import ReprMixin if t.TYPE_CHECKING: + import bentoml + ListIntStr = list[int | str] else: ListIntStr = list @@ -43,28 +41,46 @@ else: logger = logging.getLogger(__name__) -class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"): +class AmdGpuResource(Resource[t.List[str]], resource_id="amd.com/gpu"): @classmethod - def from_spec(cls, spec: int | str | list[str | int]) -> list[int]: - if not isinstance(spec, (int, str)) and not LazyType(ListIntStr).isinstance(spec): + def from_spec(cls, spec: t.Any) -> list[str]: + if not isinstance(spec, (int, str, list)): raise TypeError("AMD GPU device IDs must be int, str or a list specifing the exact GPUs to use.") + try: if isinstance(spec, int): + if spec == -1: + return [] if spec < -1: raise ValueError - return list(range(spec)) + return [str(i) for i in range(spec)] elif isinstance(spec, str): - return cls.from_spec(int(spec)) + try: + return cls.from_spec(int(spec)) + except ValueError: + if spec.startswith("GPU"): + return [spec] + raise ValueError else: - return [int(x) for x in spec] + return [str(x) for x in spec] except ValueError: - raise openllm.exceptions.OpenLLMException(f"Invalid AMD GPU resource limit '{spec}'. ") + raise OpenLLMException(f"Invalid AMD GPU resource limit '{spec}'.") - @classmethod # type: ignore (overload) - @functools.lru_cache(maxsize=1) - def from_system(cls) -> list[int]: + @classmethod + def from_system(cls) -> list[str]: """Retrieve AMD GPU from system, currently only supports on Linux. - This assumes that ROCm is setup correctly.""" + + This assumes that ROCm is setup correctly. + """ + cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if cuda_visible_devices in ("", "-1"): + return [] + if cuda_visible_devices is not None: + cuda_visible_devices = cuda_visible_devices.split(",") + if "-1" in cuda_visible_devices: + cuda_visible_devices = cuda_visible_devices[: cuda_visible_devices.index("-1")] + return cuda_visible_devices + if not psutil.LINUX: logger.debug("AMD GPU resource is only supported on Linux.") return [] @@ -84,7 +100,7 @@ class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"): num = c_uint32(0) ret = rocmsmi.rsmi_num_monitor_devices(byref(num)) if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: - return list(range(num.value)) + return [str(i) for i in range(num.value)] return [] except Exception as err: logger.debug("Failed to setup AMD GPU resource: %s", err) @@ -93,18 +109,22 @@ class AmdGpuResource(Resource[t.List[int]], resource_id="amd.com/gpu"): sys.path.remove("/opt/rocm/libexec/rocm_smi") @classmethod - def validate(cls, val: list[int]): - if any(gpu_index < 0 for gpu_index in val): - raise openllm.exceptions.OpenLLMException(f"Negative GPU device in {val}.") - if any(gpu_index >= len(cls.from_system()) for gpu_index in val): - raise openllm.exceptions.OpenLLMException( - f"GPU device index in {val} is greater than the system available: {cls.from_system()}" - ) + def validate(cls, val: list[str]): + for gpu_index_or_literal in val: + try: + idx = int(gpu_index_or_literal) + except ValueError: + raise OpenLLMException(f"Invalid AMD GPU device index: {val}") + if int(idx) < 0: + raise OpenLLMException(f"Negative GPU device in {val}.") + if int(idx) >= len(cls.from_system()): + raise OpenLLMException( + f"GPU device index in {val} is greater than the system available: {cls.from_system()}" + ) class CascadingResourceStrategy(Strategy, ReprMixin): - """This is rather an extension of bentoml._internal.runner.strategy.DefaultStrategy - where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource. + """This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource. It also respect CUDA_VISIBLE_DEVICES for both AMD and NVIDIA GPU. See https://rocm.docs.amd.com/en/develop/understand/gpu_isolation.html#cuda-visible-devices @@ -147,9 +167,9 @@ class CascadingResourceStrategy(Strategy, ReprMixin): ) if runnable_class.SUPPORTS_CPU_MULTI_THREADING: - if isinstance(workers_per_resource, float): + if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.") - return workers_per_resource + return int(workers_per_resource) return math.ceil(cpus) * workers_per_resource @@ -167,11 +187,13 @@ class CascadingResourceStrategy(Strategy, ReprMixin): workers_per_resource: int | float, worker_index: int, ) -> dict[str, t.Any]: - """ + """Get worker env for this given worker_index. + Args: - runnable_class : The runnable class to be run. - resource_request : The resource request of the runnable. - worker_index : The index of the worker, start from 0. + runnable_class: The runnable class to be run. + resource_request: The resource request of the runnable. + workers_per_resource: # of workers per resource. + worker_index: The index of the worker, start from 0. """ cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None) disabled = cuda_env in ("", "-1") diff --git a/src/openllm/_types.py b/src/openllm/_types.py index 2c58002c..a67d1876 100644 --- a/src/openllm/_types.py +++ b/src/openllm/_types.py @@ -11,35 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Types definition for OpenLLM. +"""Types definition for OpenLLM. Note that this module SHOULD NOT BE IMPORTED DURING RUNTIME, as this serve only for typing purposes. It will raises a RuntimeError if this is imported eagerly. """ from __future__ import annotations - import typing as t if not t.TYPE_CHECKING: raise RuntimeError(f"{__name__} should not be imported during runtime") -import click + import bentoml -import openllm -import transformers + from ._configuration import AdapterType -from bentoml._internal.runner.runnable import RunnableMethod -from bentoml._internal.runner.runner import RunnerMethod +if t.TYPE_CHECKING: + import click + import peft + import openllm + import transformers + from bentoml._internal.runner.runnable import RunnableMethod + from bentoml._internal.runner.runner import RunnerMethod + +AnyCallable = t.Callable[..., t.Any] DictStrAny = dict[str, t.Any] ListAny = list[t.Any] +ListStr = list[str] TupleAny = tuple[t.Any, ...] P = t.ParamSpec("P") O_co = t.TypeVar("O_co", covariant=True) +LiteralRuntime: t.TypeAlias = t.Literal["pt", "tf", "flax"] +T = t.TypeVar("T") +Ts = t.TypeVarTuple("Ts") class ClickFunctionWrapper(t.Protocol[P, O_co]): @@ -83,9 +91,19 @@ class TokenizerProtocol(_StubsMixin[_MT], t.Protocol): ... -PeftAdapterOutput = dict[t.Literal["success", "result", "error_msg"], bool | str | dict[t.Any, t.Any]] +class PeftAdapterOutput(t.TypedDict): + success: bool + result: dict[str, peft.PeftConfig] + error_msg: str -AdaptersMapping = dict[AdapterType, tuple[tuple[str | None, str | None, dict[str, t.Any]], ...]] | None + +class AdaptersTuple(TupleAny): + adapter_id: str + name: str | None + config: DictStrAny + + +AdaptersMapping = dict[AdapterType, tuple[AdaptersTuple, ...]] | None class LLMRunnable(bentoml.Runnable): diff --git a/src/openllm/cli.py b/src/openllm/cli.py index d97f0454..75120b33 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -11,25 +11,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -CLI utilities for OpenLLM. +"""CLI utilities for OpenLLM. -This extends BentoML's internal CLI CommandGroup. +This module also contains the SDK to call ``start`` and ``build`` from SDK + +Start any LLM: + +```python +openllm.start("falcon", model_id='tiiuae/falcon-7b-instruct') +``` + +Build a BentoLLM + +```python +bento = openllm.build("falcon") +``` + +Import any LLM into local store +```python +bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct') +``` """ from __future__ import annotations - import functools -import tempfile -import pkgutil -import yaml -import subprocess import importlib.util import inspect import itertools import logging import os +import pkgutil import re +import subprocess import sys +import tempfile import time import traceback import typing as t @@ -41,6 +55,7 @@ import fs.errors import inflection import orjson import psutil +import yaml from bentoml_cli.utils import BentoMLCommandGroup from bentoml_cli.utils import opt_callback from simple_di import Provide @@ -59,14 +74,19 @@ from .utils import LazyLoader from .utils import LazyType from .utils import analytics from .utils import bentoml_cattr +from .utils import codegen from .utils import configure_logging +from .utils import dantic from .utils import first_not_none from .utils import get_debug_mode from .utils import get_quiet_mode -from .utils import gpu_count, dantic +from .utils import gpu_count +from .utils import is_jupyter_available +from .utils import is_jupytext_available +from .utils import is_notebook_available from .utils import is_peft_available from .utils import is_torch_available -from .utils import is_transformers_supports_agent, is_jupyter_available, is_jupytext_available, is_notebook_available +from .utils import is_transformers_supports_agent from .utils import resolve_user_filepath from .utils import set_debug_mode from .utils import set_quiet_mode @@ -75,10 +95,12 @@ from .utils import set_quiet_mode if t.TYPE_CHECKING: import torch + from ._types import AnyCallable from ._types import ClickFunctionWrapper - from ._types import F + from ._types import DictStrAny + from ._types import ListStr + from ._types import LiteralRuntime from ._types import P - from .models.auto.factory import _BaseAutoLLMClass ServeCommand = t.Literal["serve", "serve-grpc"] OutputLiteral = t.Literal["json", "pretty", "porcelain"] @@ -89,6 +111,19 @@ else: torch = LazyLoader("torch", globals(), "torch") +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if sys.version_info[:2] >= (3, 11): + from typing import overload +else: + from typing_extensions import overload + +if sys.version_info[:2] >= (3, 12): + from typing import override +else: + from typing_extensions import override + + logger = logging.getLogger(__name__) COLUMNS = int(os.environ.get("COLUMNS", 120)) @@ -162,44 +197,18 @@ def workers_per_resource_option(factory: t.Any, build: bool = False): See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more information. By default, this is set to 1. - > **Note**: ``--workers-per-resource`` will also accept the following strategies: - > - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. - > - ``conserved``: This will determine the number of available GPU resources, and only assign - > one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is - > equivalent to ``--workers-per-resource 0.25``. + **Note**: ``--workers-per-resource`` will also accept the following strategies: + + - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. + + - ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``. """ if build: help_str += """\n **Note**: The workers value passed into 'build' will determine how the LLM can be provisioned in Kubernetes as well as in standalone container. This will ensure it has the same effect with 'openllm start --workers ...'""" - return factory.option( - "--workers-per-resource", - default=None, - help=help_str, - callback=parse_workers_per_resource_callback, - required=False, - ) - - -_wpr_strategies = {"round_robin", "conserved"} - - -def parse_workers_per_resource_callback(_: click.Context, param: click.Parameter, value: str | float | None) -> float: - if value is None: - return 1.0 - - if isinstance(value, str): - if value == "round_robin": - return 1.0 - elif value == "conserved": - return float(1 / len(gpu_count())) - else: - try: - value = float(value) - except ValueError: - raise ValueError(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies.") - return float(value) + return factory.option("--workers-per-resource", default=None, help=help_str, type=str, required=False) def quantize_option(factory: t.Any, build: bool = False, model_env: EnvVarMixin | None = None): @@ -264,9 +273,10 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): NUMBER_OF_COMMON_PARAMS = 4 # parameters in common_params + 1 faked group option header @staticmethod - def common_params(f: F[P, t.Any]) -> ClickFunctionWrapper[..., t.Any]: + def common_params(f: AnyCallable): """This is not supposed to be used with unprocessed click function. - This should be used a the last currying from common_params -> usage_tracking -> exception_handling + + This should be used a the last currying from common_params -> usage_tracking -> exception_handling. """ # The following logics is similar to one of BentoMLCommandGroup @@ -300,14 +310,13 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): return f(*args, **attrs) - return t.cast("ClickFunctionWrapper[..., t.Any]", wrapper) + return wrapper @staticmethod - def usage_tracking( - func: ClickFunctionWrapper[..., t.Any], group: click.Group, **attrs: t.Any - ) -> ClickFunctionWrapper[..., t.Any]: + def usage_tracking(func: AnyCallable, group: click.Group, **attrs: t.Any) -> AnyCallable: """This is not supposed to be used with unprocessed click function. - This should be used a the last currying from common_params -> usage_tracking -> exception_handling + + This should be used a the last currying from common_params -> usage_tracking -> exception_handling. """ command_name = attrs.get("name", func.__name__) @@ -336,14 +345,13 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): analytics.track(event) raise - return t.cast("ClickFunctionWrapper[..., t.Any]", wrapper) + return wrapper @staticmethod - def exception_handling( - func: ClickFunctionWrapper[..., t.Any], group: click.Group, **attrs: t.Any - ) -> ClickFunctionWrapper[..., t.Any]: + def exception_handling(func: AnyCallable, group: click.Group, **attrs: t.Any) -> ClickFunctionWrapper[..., t.Any]: """This is not supposed to be used with unprocessed click function. - This should be used a the last currying from common_params -> usage_tracking -> exception_handling + + This should be used a the last currying from common_params -> usage_tracking -> exception_handling. """ command_name = attrs.get("name", func.__name__) @@ -371,7 +379,9 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): return start_command_factory(bentoml.get(cmd_name), _context_settings=_CONTEXT_SETTINGS) except bentoml.exceptions.NotFound: pass - raise click.BadArgumentUsage(f"{cmd_name} is not a valid model identifier supported by OpenLLM.") + raise click.BadArgumentUsage( + f"{cmd_name} is not a valid model identifier supported by OpenLLM." + ) from None elif ctx.command.name == "start-grpc": try: return _cached_grpc[cmd_name] @@ -383,7 +393,9 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): ) except bentoml.exceptions.NotFound: pass - raise click.BadArgumentUsage(f"{cmd_name} is not a valid model identifier supported by OpenLLM.") + raise click.BadArgumentUsage( + f"{cmd_name} is not a valid model identifier supported by OpenLLM." + ) from None return super().get_command(ctx, cmd_name) def list_commands(self, ctx: click.Context) -> list[str]: @@ -392,18 +404,20 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): return super().list_commands(ctx) - def command(self, *args: t.Any, **attrs: t.Any) -> F[[t.Callable[P, t.Any]], click.Command]: - """Override the default 'cli.command' with supports for aliases for given command, and it - wraps the implementation with common parameters. - """ + @override + def command(self, *args: t.Any, **attrs: t.Any): + """Override the default 'cli.command' with supports for aliases for given command, and it wraps the implementation with common parameters.""" if "context_settings" not in attrs: attrs["context_settings"] = {} if "max_content_width" not in attrs["context_settings"]: attrs["context_settings"]["max_content_width"] = 120 aliases = attrs.pop("aliases", None) - def wrapper(f: F[P, t.Any]) -> click.Command: - name = f.__name__.lower().replace("_", "-") + def wrapper(f: AnyCallable) -> click.Command: + name = f.__name__.lower() + if name.endswith("_command"): + name = name[:-8] + name = name.replace("_", "-") attrs.setdefault("help", inspect.getdoc(f)) attrs.setdefault("name", name) @@ -412,9 +426,6 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): # Wrap into OpenLLM tracking wrapped = self.usage_tracking(wrapped, self, **attrs) # Wrap into exception handling - if "do_not_track" in attrs: - # We hit this branch when ctx.invoke the function - attrs.pop("do_not_track") wrapped = self.exception_handling(wrapped, self, **attrs) # move common parameters to end of the parameters list @@ -434,49 +445,29 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): return cmd - # XXX: The current type coercion is not ideal, but we can really - # loosely define it - return t.cast("F[[t.Callable[..., t.Any]], click.Command]", wrapper) - - def group(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Callable[P, t.Any]], click.Group]: - aliases = kwargs.pop("aliases", None) - - def decorator(f: t.Callable[P, t.Any]): - # create the main group - grp = super(BentoMLCommandGroup, self).group(*args, **kwargs)(f) - - if aliases is not None: - assert grp.name - self._commands[grp.name] = aliases - self._aliases.update({k: grp.name for k in aliases}) - - return grp - - return decorator + return wrapper @click.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="openllm") @click.version_option(__version__, "--version", "-v") def cli(): - """ - \b + """\b ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ - ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝ + ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝. \b An open platform for operating large language models in production. Fine-tune, serve, deploy, and monitor any LLMs with ease. - """ + """ # noqa: D205 @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start", aliases=["start-http"]) -def start_cli(): - """ - Start any LLM as a REST server. +def start_command(): + """Start any LLM as a REST server. \b ```bash @@ -486,9 +477,8 @@ def start_cli(): @cli.group(cls=OpenLLMCommandGroup, context_settings=_CONTEXT_SETTINGS, name="start-grpc") -def start_grpc_cli(): - """ - Start any LLM as a gRPC server. +def start_grpc_command(): + """Start any LLM as a gRPC server. \b ```bash @@ -510,7 +500,7 @@ else: def parse_serve_args(serve_grpc: bool): - """Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`""" + """Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`.""" from bentoml_cli.cli import cli command = "serve" if not serve_grpc else "serve-grpc" @@ -519,9 +509,7 @@ def parse_serve_args(serve_grpc: bool): help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]", ) - def decorator( - f: t.Callable[t.Concatenate[int, str | None, P], openllm.LLMConfig] - ) -> ClickFunctionWrapper[P, openllm.LLMConfig]: + def decorator(f: t.Callable[t.Concatenate[int, str | None, P], openllm.LLMConfig]): serve_command = cli.commands[command] # The first variable is the argument bento # and the last three are shared default, which we don't need. @@ -535,7 +523,7 @@ def parse_serve_args(serve_grpc: bool): # type can be determine from default value attrs.pop("type") param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts")) - f = t.cast("WrappedServeFunction[P]", cog.optgroup.option(*param_decls, **attrs)(f)) + f = cog.optgroup.option(*param_decls, **attrs)(f) return group(f) @@ -546,9 +534,7 @@ _http_server_args = parse_serve_args(False) _grpc_server_args = parse_serve_args(True) -def start_decorator( - llm_config: openllm.LLMConfig, serve_grpc: bool = False -) -> t.Callable[[t.Callable[P, t.Any]], F[P, t.Any]]: +def start_decorator(llm_config: openllm.LLMConfig, serve_grpc: bool = False): opts = [ llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args, @@ -610,7 +596,7 @@ def start_decorator( - `--adapter-id /path/to/adapter` (local adapter) - - `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub) + j - `--adapter-id remote/adapter` (remote adapter from HuggingFace Hub) - `--adapter-id remote/adapter:eng_lora` (two previous adapter options with the given adapter_name) @@ -629,9 +615,10 @@ def start_decorator( callback=_id_callback, metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]", ), + click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True), ] - def decorator(f: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: + def decorator(f: AnyCallable) -> AnyCallable: for opt in reversed(opts): f = opt(f) return f @@ -644,8 +631,8 @@ def parse_config_options( server_timeout: int, workers_per_resource: float, device: tuple[str, ...] | None, - environ: dict[str, t.Any], -) -> dict[str, t.Any]: + environ: DictStrAny, +) -> DictStrAny: _bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "") _bentoml_config_options_opts = [ "tracing.sample_rate=1.0", @@ -671,16 +658,13 @@ def parse_config_options( def start_command_factory( model_name_or_bento: str | bentoml.Bento, - _context_settings: dict[str, t.Any] | None = None, + _context_settings: DictStrAny | None = None, _serve_grpc: bool = False, ) -> click.Command: """Generate a 'click.Command' for any given LLM. Args: - model_name: The name of the model - factory: The click.Group to add the command to - _context_settings: The context settings to use for the command - _serve_grpc: Whether to serve the model via gRPC or HTTP + model_name_or_bento: The name of the model or the ``bentoml.Bento`` instance. Returns: The click.Command for starting the model server @@ -688,8 +672,7 @@ def start_command_factory( Note that the internal commands will return the llm_config and a boolean determine whether the server is run with GPU or not. """ - configure_logging() - group = start_cli if not _serve_grpc else start_grpc_cli + group = start_command if not _serve_grpc else start_grpc_command if isinstance(model_name_or_bento, bentoml.Bento): if "start_name" not in model_name_or_bento.info.labels: @@ -791,11 +774,13 @@ def prerequisite_check( ctx: click.Context, llm_config: openllm.LLMConfig, env: EnvVarMixin, - gpu_available: tuple[int, ...], + gpu_available: tuple[str, ...], quantize: t.LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int, ) -> None: + if get_debug_mode(): + _echo("Running prerequisite check.", fg="magenta") if quantize: if len(gpu_available) < 1: _echo(f"Quantization requires at least 1 GPU (got {len(gpu_available)})", fg="red") @@ -829,6 +814,9 @@ def prerequisite_check( ) +_wpr_strategies = {"round_robin", "conserved"} + + def start_bento( group: click.Group, bento: bentoml.Bento, @@ -844,6 +832,20 @@ def start_bento( command_attrs["help"] = start_bento_docstring(bento, llm_config, serve_grpc) + # Now we have to format the model_id accordingly based on the model_fs + model_type = bento.info.labels["_type"] + model_framework = bento.info.labels["_framework"] + # the models should have the type + try: + model_store = ModelStore(bento._fs.opendir("models")) + model = model_store.get(f"{model_framework}-{model_type}") + except fs.errors.ResourceNotFound: + # new behaviour with BentoML models + _model_store = BentoMLContainer.model_store.get() + model = _model_store.get(f"{model_framework}-{model_type}") + except bentoml.exceptions.NotFound: + raise OpenLLMException(f"Failed to find models for {llm_config['start_name']}") from None + @group.command(**command_attrs) @start_decorator(llm_config, serve_grpc=serve_grpc) @click.pass_context @@ -851,15 +853,19 @@ def start_bento( ctx: click.Context, server_timeout: int | None, model_id: str | None, - workers_per_resource: float | None, + workers_per_resource: t.LiteralString | float | None, device: tuple[str, ...] | None, quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool, adapter_id: str | None, + return_process: bool, **attrs: t.Any, - ) -> openllm.LLMConfig: + ) -> openllm.LLMConfig | subprocess.Popen[bytes]: + if model_id is not None: + _echo("'model_id' has no effect when starting a BentoLLM", fg="yellow") + adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None) config, server_attrs = llm_config.model_validate_click(**attrs) @@ -873,6 +879,25 @@ def start_bento( server_attrs.setdefault("production", not development) workers_per_resource = first_not_none(workers_per_resource, default=config["workers_per_resource"]) + + if isinstance(workers_per_resource, str): + if workers_per_resource == "round_robin": + workers_per_resource = 1.0 + elif workers_per_resource == "conserved": + if device: + available_gpu = device + else: + available_gpu = gpu_count() + if len(available_gpu) != 0: + workers_per_resource = float(1 / len(available_gpu)) + else: + workers_per_resource = 1.0 + else: + try: + workers_per_resource = float(workers_per_resource) + except ValueError: + ctx.fail(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies.") + num_workers = int(1 / workers_per_resource) # Create a new model env to work with the envvar during CLI invocation @@ -896,7 +921,7 @@ def start_bento( env.runtime: env.runtime_value, "BENTOML_DEBUG": str(get_debug_mode()), "BENTOML_HOME": os.environ.get("BENTOML_HOME", BentoMLContainer.bentoml_home.get()), - "OPENLLM_MODEL_ID": model_id, + "OPENLLM_MODEL_ID": model.path, } ) @@ -915,8 +940,18 @@ def start_bento( server = bentoml.HTTPServer(bento, **server_attrs) analytics.track_start_init(config) + server.start(env=start_env, text=True) + process = server.process + assert process + + if return_process: + return process + try: - server.start(env=start_env, text=True, blocking=True) + assert process.stdout + with process: + for line in iter(process.stdout.readline, b""): + _echo(line.strip(), fg="white") except Exception as err: _echo(f"Error caught while starting LLM Server:\n{err}", fg="red") raise @@ -953,15 +988,16 @@ def start_model( ctx: click.Context, server_timeout: int | None, model_id: str | None, - workers_per_resource: float | None, + workers_per_resource: str | float | None, device: tuple[str, ...] | None, quantize: t.Literal["int8", "int4", "gptq"] | None, bettertransformer: bool | None, runtime: t.Literal["ggml", "transformers"], fast: bool, adapter_id: str | None, + return_process: bool, **attrs: t.Any, - ) -> openllm.LLMConfig: + ) -> openllm.LLMConfig | subprocess.Popen[bytes]: adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None) config, server_attrs = llm_config.model_validate_click(**attrs) @@ -975,6 +1011,25 @@ def start_model( server_attrs.setdefault("production", not development) workers_per_resource = first_not_none(workers_per_resource, default=config["workers_per_resource"]) + + if isinstance(workers_per_resource, str): + if workers_per_resource == "round_robin": + workers_per_resource = 1.0 + elif workers_per_resource == "conserved": + if device: + available_gpu = device + else: + available_gpu = gpu_count() + if len(available_gpu) != 0: + workers_per_resource = float(1 / len(available_gpu)) + else: + workers_per_resource = 1.0 + else: + try: + workers_per_resource = float(workers_per_resource) + except ValueError: + ctx.fail(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies.") + num_workers = int(1 / workers_per_resource) # Create a new model env to work with the envvar during CLI invocation @@ -1005,10 +1060,7 @@ def start_model( if adapter_map: _echo(f"OpenLLM will convert '{model_name}' to use provided adapters layers: {list(adapter_map)}") - llm = t.cast( - "_BaseAutoLLMClass", - openllm[env.framework_value], # type: ignore (internal API) - ).for_model( + llm = openllm.infer_auto_class(env.framework_value).for_model( model_name, model_id=model_id, llm_config=config, @@ -1027,7 +1079,6 @@ def start_model( "OPENLLM_MODEL": model_name, "OPENLLM_MODEL_ID": llm.model_id, "OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(), - "OPENLLM_SERVING": str(True), } ) @@ -1042,28 +1093,36 @@ def start_model( server = bentoml.HTTPServer("_service.py:svc", **server_attrs) analytics.track_start_init(llm.config) + server.start(env=start_env, text=True) + process = server.process + assert process + + if return_process: + return process + try: - server.start(env=start_env, text=True, blocking=True) + assert process.stdout + with process: + for line in iter(process.stdout.readline, b""): + _echo(line.strip(), fg="white") except Exception as err: _echo(f"Error caught while starting LLM Server:\n{err}", fg="red") raise - else: - if not get_debug_mode(): - cmd_name = f"openllm build {model_name}" - if adapter_map is not None: - cmd_name += " " + " ".join( - [ - f"--adapter-id {s}" - for s in [ - f"{p}:{name}" if name not in (None, "default") else p - for p, name in adapter_map.items() - ] + finally: + cmd_name = f"openllm build {model_name}" + if adapter_map is not None: + cmd_name += " " + " ".join( + [ + f"--adapter-id {s}" + for s in [ + f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items() ] - ) - _echo( - f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", - fg="blue", + ] ) + _echo( + f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", + fg="blue", + ) # NOTE: Return the configuration for telemetry purposes. return config @@ -1071,7 +1130,7 @@ def start_model( return start_cmd -@cli.command(name="import", aliases=["download", "download-models"]) +@cli.command(name="import", aliases=["download"]) @click.argument( "model", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), @@ -1093,14 +1152,14 @@ def start_model( @quantize_option(click) @click.option("--machine", is_flag=True, default=False, hidden=True) @click.option("--implementation", type=click.Choice(["pt", "tf", "flax"]), default=None, hidden=True) -def download_models( +def download_models_command( model: str, model_id: str | None, model_version: str | None, output: OutputLiteral, runtime: t.Literal["ggml", "transformers"], machine: bool, - implementation: t.Literal["pt", "tf", "flax"] | None, + implementation: LiteralRuntime | None, quantize: t.Literal["int8", "int4", "gptq"] | None, ): """Setup LLM interactively. @@ -1134,16 +1193,8 @@ def download_models( > only use this option if you want the weight to be quantized by default. Note that OpenLLM also > support on-demand quantisation during initial startup. """ - if output == "porcelain" or machine: - configure_logging() - if machine: - output = "porcelain" - - impl = first_not_none(implementation, default=EnvVarMixin(model).framework_value) - llm = t.cast( - "_BaseAutoLLMClass", - openllm[impl], # type: ignore - ).for_model( + impl: t.Literal["pt", "tf", "flax"] = first_not_none(implementation, default=EnvVarMixin(model).framework_value) + llm = openllm.infer_auto_class(impl).for_model( model, model_id=model_id, model_version=model_version, @@ -1158,7 +1209,7 @@ def download_models( _ref = bentoml.models.get(llm.tag) _previously_saved = True except bentoml.exceptions.NotFound: - if output == "pretty": + if not machine and output == "pretty": _echo( f"'{model}' with 'model_id={model_id}' does not exists in local store. Saving to store...", fg="yellow", @@ -1167,29 +1218,25 @@ def download_models( _ref = llm.import_model(trust_remote_code=llm.__llm_trust_remote_code__) - if machine: - # NOTE: When debug is enabled, - # We will prefix the tag with __tag__ and we can use regex to correctly - # get the tag from 'bentoml.bentos.build|build_bentofile' - _echo(f"__tag__:{_ref.tag}", fg="white") - elif output == "pretty": - if _previously_saved: + if not machine: + if output == "pretty": + if _previously_saved: + _echo( + f"{model} with 'model_id={model_id}' is already setup for framework '{impl}': {_ref.tag!s}", + nl=True, + fg="yellow", + ) + else: + _echo(f"Saved model: {_ref.tag}") + elif output == "json": _echo( - f"{model} with 'model_id={model_id}' is already setup for framework '{impl}': {str(_ref.tag)}", - nl=True, - fg="yellow", + orjson.dumps( + {"previously_setup": _previously_saved, "framework": impl, "tag": str(_ref.tag)}, + option=orjson.OPT_INDENT_2, + ).decode() ) else: - _echo(f"Saved model: {_ref.tag}") - elif output == "json": - _echo( - orjson.dumps( - {"previously_setup": _previously_saved, "framework": impl, "tag": str(_ref.tag)}, - option=orjson.OPT_INDENT_2, - ).decode() - ) - else: - _echo(_ref.tag) + _echo(_ref.tag) if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() @@ -1203,24 +1250,360 @@ _cached_grpc = { } +@overload def _start( model_name: str | bentoml.Bento, - framework: t.Literal["flax", "tf", "pt"] | None = None, - **attrs: t.Any, -): - """Python API to start a LLM server.""" - _serve_grpc = attrs.pop("_serve_grpc", False) + /, + model_id: str | None = ..., + timeout: int = ..., + workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = ..., + device: tuple[str, ...] | t.Literal["all"] | None = ..., + quantize: t.Literal["int8", "int4", "gptq"] | None = ..., + bettertransformer: bool | None = ..., + runtime: t.Literal["ggml", "transformers"] = ..., + fast: bool = ..., + adapter_map: dict[t.LiteralString, str | None] | None = ..., + framework: t.Literal["flax", "tf", "pt"] | None = ..., + additional_args: ListStr | None = ..., + _serve_grpc: bool = ..., + __test__: t.Literal[False] = ..., +) -> openllm.LLMConfig: + ... + +@overload +def _start( + model_name: str | bentoml.Bento, + /, + model_id: str | None = ..., + timeout: int = ..., + workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = ..., + device: tuple[str, ...] | t.Literal["all"] | None = ..., + quantize: t.Literal["int8", "int4", "gptq"] | None = ..., + bettertransformer: bool | None = ..., + runtime: t.Literal["ggml", "transformers"] = ..., + fast: bool = ..., + adapter_map: dict[t.LiteralString, str | None] | None = ..., + framework: t.Literal["flax", "tf", "pt"] | None = ..., + additional_args: ListStr | None = ..., + _serve_grpc: bool = ..., + __test__: t.Literal[True] = ..., +) -> subprocess.Popen[bytes]: + ... + + +def _start( + model_name: str | bentoml.Bento, + /, + model_id: str | None = None, + timeout: int = 30, + workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None, + device: tuple[str, ...] | t.Literal["all"] | None = None, + quantize: t.Literal["int8", "int4", "gptq"] | None = None, + bettertransformer: bool | None = None, + runtime: t.Literal["ggml", "transformers"] = "transformers", + fast: bool = False, + adapter_map: dict[t.LiteralString, str | None] | None = None, + framework: t.Literal["flax", "tf", "pt"] | None = None, + additional_args: ListStr | None = None, + _serve_grpc: bool = False, + __test__: bool = False, +) -> openllm.LLMConfig | subprocess.Popen[bytes]: + """Python API to start a LLM server. These provides one-to-one mapping to CLI arguments. + + For all additional arguments, pass it as string to ``additional_args``. For example, if you want to + pass ``--port 5001``, you can pass ``additional_args=["--port", "5001"]`` + + > **Note**: This will create a blocking process, so if you use this API, you can create a running sub thread + > to start the server instead of blocking the main thread. + + ``openllm.start`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI interaction. + + > **Note**: ``quantize`` and ``bettertransformer`` are mutually exclusive. + + Args: + model_name: The model name to start this LLM + model_id: Optional model id for this given LLM + timeout: The server timeout + workers_per_resource: Number of workers per resource assigned. + See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy + for more information. By default, this is set to 1. + + > **Note**: ``--workers-per-resource`` will also accept the following strategies: + + > - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. + + > - ``conserved``: Thjis will determine the number of available GPU resources, and only assign + one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is + equivalent to ``--workers-per-resource 0.25``. + device: Assign GPU devices (if available) to this LLM. By default, this is set to ``None``. It also accepts 'all' + argument to assign all available GPUs to this LLM. + quantize: Quantize the model weights. This is only applicable for PyTorch models. + Possible quantisation strategies: + - int8: Quantize the model with 8bit (bitsandbytes required) + - int4: Quantize the model with 4bit (bitsandbytes required) + - gptq: Quantize the model with GPTQ (autogptq required) + bettertransformer: Convert given model to FastTransformer with PyTorch. + runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. + fast: Enable fast mode. This will skip downloading models, and will raise errors if given model_id does not exists under local store. + adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``. + framework: The framework to use for this LLM. By default, this is set to ``pt``. + additional_args: Additional arguments to pass to ``openllm start``. + """ if isinstance(model_name, str): _ModelEnv = EnvVarMixin(model_name) if framework is None: framework = _ModelEnv.framework_value os.environ[_ModelEnv.framework] = framework - start_command_factory(model_name, _serve_grpc=_serve_grpc)(standalone_mode=False, **attrs) + + args: ListStr = ["--runtime", runtime] + if model_id: + if isinstance(model_id, bentoml.Bento): + logger.warning("'model_id' has no effect if since %s is already a Bento.", model_name) + else: + args.extend(["--model-id", model_id]) + if timeout: + args.extend(["--server-timeout", str(timeout)]) + if workers_per_resource: + args.extend( + [ + "--workers-per-resource", + str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource, + ] + ) + if device and not os.getenv("CUDA_VISIBLE_DEVICES"): + args.extend(["--device", ",".join(device)]) + + if quantize and bettertransformer: + raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") + + if quantize: + args.extend(["--quantize", str(quantize)]) + if bettertransformer: + args.append("--bettertransformer") + if fast: + args.append("--fast") + if adapter_map: + args.extend( + list( + itertools.chain.from_iterable( + [["--adapter-id", f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()] + ) + ) + ) + if additional_args: + args.extend(additional_args) + + if __test__: + args.append("--return-process") + + return start_command_factory(model_name, _context_settings=_CONTEXT_SETTINGS, _serve_grpc=_serve_grpc).main( + args=args if len(args) > 0 else None, + standalone_mode=False, + ) -start = functools.partial(_start, _serve_grpc=False) -start_grpc = functools.partial(_start, _serve_grpc=True) +@overload +def _build( + model_name: str, + /, + *, + model_id: str | None = ..., + model_version: str | None = ..., + quantize: t.Literal["int8", "int4", "gptq"] | None = ..., + bettertransformer: bool | None = ..., + adapter_map: dict[str, str | None] | None = ..., + build_ctx: str | None = ..., + enable_features: tuple[str, ...] | None = ..., + workers_per_resource: int | float | None = ..., + runtime: t.Literal["ggml", "transformers"] = ..., + dockerfile_template: str | None = ..., + overwrite: bool = ..., + format: t.Literal["bento"] = "bento", + additional_args: list[str] | None = ..., +) -> bentoml.Bento: + ... + + +@overload +def _build( + model_name: str, + /, + *, + model_id: str | None = ..., + model_version: str | None = ..., + quantize: t.Literal["int8", "int4", "gptq"] | None = ..., + bettertransformer: bool | None = ..., + adapter_map: dict[str, str | None] | None = ..., + build_ctx: str | None = ..., + enable_features: tuple[str, ...] | None = ..., + workers_per_resource: int | float | None = ..., + runtime: t.Literal["ggml", "transformers"] = ..., + dockerfile_template: str | None = ..., + overwrite: bool = ..., + format: t.Literal["container"] = ..., + additional_args: list[str] | None = ..., +) -> str: + ... + + +def _build( + model_name: str, + /, + *, + model_id: str | None = None, + model_version: str | None = None, + quantize: t.Literal["int8", "int4", "gptq"] | None = None, + bettertransformer: bool | None = None, + adapter_map: dict[str, str | None] | None = None, + build_ctx: str | None = None, + enable_features: tuple[str, ...] | None = None, + workers_per_resource: int | float | None = None, + runtime: t.Literal["ggml", "transformers"] = "transformers", + dockerfile_template: str | None = None, + overwrite: bool = False, + format: t.Literal["bento", "container"] = "bento", + additional_args: list[str] | None = None, +) -> bentoml.Bento | str: + """Package a LLM into a Bento. + + The LLM will be built into a BentoService with the following structure: + if ``quantize`` is passed, it will instruct the model to be quantized dynamically during serving time. + if ``bettertransformer`` is passed, it will instruct the model to apply FasterTransformer during serving time. + + ``openllm.build`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as ``openllm build`` CLI. + + > **Note**: ``quantize`` and ``bettertransformer`` are mutually exclusive. + + Args: + model_name: The model name to start this LLM + model_id: Optional model id for this given LLM + model_version: Optional model version for this given LLM + quantize: Quantize the model weights. This is only applicable for PyTorch models. + Possible quantisation strategies: + - int8: Quantize the model with 8bit (bitsandbytes required) + - int4: Quantize the model with 4bit (bitsandbytes required) + - gptq: Quantize the model with GPTQ (autogptq required) + bettertransformer: Convert given model to FastTransformer with PyTorch. + adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``. + build_ctx: The build context to use for building BentoLLM. By default, it sets to current directory. + enable_features: Additional OpenLLM features to be included with this BentoLLM. + workers_per_resource: Number of workers per resource assigned. + See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy + for more information. By default, this is set to 1. + + > **Note**: ``--workers-per-resource`` will also accept the following strategies: + + > - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. + + > - ``conserved``: This will determine the number of available GPU resources, and only assign + one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is + equivalent to ``--workers-per-resource 0.25``. + runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. + dockerfile_template: The dockerfile template to use for building BentoLLM. See + https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template. + overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``. + format: The output format to build this LLM. By default it will build the BentoLLM. 'container' is equivalent of 'openllm build && bentoml containerize ' + additional_args: Additional arguments to pass to ``openllm build``. + + Returns: + ``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud. + If 'format="container"', then it returns the default 'container_name:container_tag' + """ + args: ListStr = [model_name, "--runtime", runtime, "--format", format] + + if quantize and bettertransformer: + raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") + + if quantize: + args.extend(["--quantize", quantize]) + if bettertransformer: + args.append("--bettertransformer") + + if model_id: + args.extend(["--model-id", model_id]) + if build_ctx: + args.extend(["--build-ctx", build_ctx]) + if enable_features: + args.extend([f"--enable-features={f}" for f in enable_features]) + if workers_per_resource: + args.extend(["--workers-per-resource", str(workers_per_resource)]) + if overwrite: + args.append("--overwrite") + if adapter_map: + args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) + if model_version: + args.extend(["--model-version", model_version]) + if dockerfile_template: + args.extend(["--dockerfile-template", dockerfile_template]) + if additional_args: + args.extend(additional_args) + + return build_command.main(args=args, standalone_mode=False) + + +def _import_model( + model_name: str, + /, + *, + model_id: str | None = None, + model_version: str | None = None, + runtime: t.Literal["ggml", "transformers"] = "transformers", + implementation: LiteralRuntime = "pt", + quantize: t.Literal["int8", "int4", "gptq"] | None = None, + additional_args: t.Sequence[str] | None = None, +) -> bentoml.Model: + """Import a LLM into local store. + + > **Note**: If ``quantize`` is passed, the model weights will be saved as quantized weights. You should + > only use this option if you want the weight to be quantized by default. Note that OpenLLM also + > support on-demand quantisation during initial startup. + + ``openllm.download`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``. + + > **Note**: ``openllm.start`` will automatically invoke ``openllm.download`` under the hood. + + Args: + model_name: The model name to start this LLM + model_id: Optional model id for this given LLM + model_version: Optional model version for this given LLM + runtime: The runtime to use for this LLM. By default, this is set to ``transformers``. In the future, this will include supports for GGML. + implementation: The implementation to use for this LLM. By default, this is set to ``pt``. + quantize: Quantize the model weights. This is only applicable for PyTorch models. + Possible quantisation strategies: + - int8: Quantize the model with 8bit (bitsandbytes required) + - int4: Quantize the model with 4bit (bitsandbytes required) + - gptq: Quantize the model with GPTQ (autogptq required) + additional_args: Additional arguments to pass to ``openllm import``. + + Returns: + ``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud. + """ + args = [model_name, "--runtime", runtime, "--implementation", implementation, "--machine"] + if model_id is not None: + args.append(model_id) + if model_version is not None: + args.extend(["--model-version", str(model_version)]) + if additional_args is not None: + args.extend(additional_args) + if quantize is not None: + args.extend(["--quantize", quantize]) + return download_models_command.main(args=args, standalone_mode=False) + + +def _list_models() -> DictStrAny: + """List all available models within the local store.""" + args = ["-o", "json", "--show-available", "--machine"] + return models_command.main(args=args, standalone_mode=False) + + +start, start_grpc, build, import_model, list_models = ( + codegen.gen_sdk(_start, _serve_grpc=False), + codegen.gen_sdk(_start, _serve_grpc=True), + codegen.gen_sdk(_build), + codegen.gen_sdk(_import_model), + codegen.gen_sdk(_list_models), +) @cli.command(context_settings={"token_normalize_func": inflection.underscore}) @@ -1229,7 +1612,6 @@ start_grpc = functools.partial(_start, _serve_grpc=True) ) @model_id_option(click) @output_option -@click.option("--machine", is_flag=True, default=False, hidden=True) @click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.") @workers_per_resource_option(click, build=True) @cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name="Optimisation options.") @@ -1270,8 +1652,15 @@ start_grpc = functools.partial(_start, _serve_grpc=True) type=click.File(), help="Optional custom dockerfile template to be used with this BentoLLM.", ) +@click.option( + "--format", + default="bento", + type=click.Choice(["bento", "container"]), + help="The output format for 'openllm build'. By default this will build a BentoLLM. 'container' is the shortcut of 'openllm build && bentoml containerize'.", + hidden=not get_debug_mode(), +) @click.pass_context -def build( +def build_command( ctx: click.Context, model_name: str, model_id: str | None, @@ -1284,9 +1673,9 @@ def build( workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, - machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, + format: t.Literal["bento", "container"], **attrs: t.Any, ): """Package a given models into a Bento. @@ -1300,6 +1689,8 @@ def build( > NOTE: To run a container built from this Bento with GPU support, make sure > to have https://github.com/NVIDIA/nvidia-container-toolkit install locally. """ + from bentoml_cli.cli import cli + from ._package import create_bento adapter_map: dict[str, str | None] | None = None @@ -1316,11 +1707,6 @@ def build( # we are just doing the parsing here. adapter_map[_adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None - if output == "porcelain" or machine: - configure_logging() - if machine: - output = "porcelain" - if output == "pretty": if overwrite: _echo(f"Overwriting existing Bento for {model_name}.", fg="yellow") @@ -1335,8 +1721,6 @@ def build( llm_config = openllm.AutoConfig.for_model(model_name) - logger.info("Packing '%s' into a Bento%s...", model_name, f" with 'kwargs={attrs}' " if attrs else "") - # NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError # during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path try: @@ -1345,10 +1729,7 @@ def build( os.environ["OPENLLM_ADAPTER_MAP"] = orjson.dumps(adapter_map).decode() framework_envvar = llm_config["env"].framework_value - llm = t.cast( - "_BaseAutoLLMClass", - openllm[framework_envvar], # type: ignore (internal API) - ).for_model( + llm = openllm.infer_auto_class(framework_envvar).for_model( model_name, model_id=model_id, llm_config=llm_config, @@ -1424,12 +1805,7 @@ def build( if current_adapter_map_envvar is not None: os.environ["OPENLLM_ADAPTER_MAP"] = current_adapter_map_envvar - if machine: - # NOTE: When debug is enabled, - # We will prefix the tag with __tag__ and we can use regex to correctly - # get the tag from 'bentoml.bentos.build|build_bentofile' - _echo(f"__tag__:{bento.tag}", fg="white") - elif output == "pretty": + if output == "pretty": if not get_quiet_mode(): _echo("\n" + OPENLLM_FIGLET, fg="white") if not _previously_built: @@ -1441,11 +1817,17 @@ def build( ) _echo( - "\nPossible next steps:\n\n" - + "* Push to BentoCloud with `bentoml push`:\n" + "📖 Next steps:\n\n" + + "* Serving BentoLLM locally with 'openllm start':\n" + + f" $ openllm start {bento.tag}\n\n" + + "* Push to BentoCloud with 'bentoml push':\n" + f" $ bentoml push {bento.tag}\n\n" - + "* Containerize your Bento with `bentoml containerize`:\n" - + f" $ bentoml containerize {bento.tag}\n\n" + + "* Containerize your Bento with 'bentoml containerize':\n" + + f" $ bentoml containerize {bento.tag}" + + " --opt progress=plain" + if get_debug_mode() + else "" + + "\n\n" + " Tip: To enable additional BentoML features for 'containerize', " + "use '--enable-features=FEATURE[,FEATURE]' " + "[see 'bentoml containerize -h' for more advanced usage]\n", @@ -1456,7 +1838,30 @@ def build( else: _echo(bento.tag) - return bento + if format == "bento": + return bento + + backend = os.getenv("BENTOML_CONTAINERIZE_BACKEND", "docker") + _echo(f"\nBuilding {bento} into a LLMContainer using backend '{backend}'", fg="magenta") + args = ["--backend", backend] + if get_debug_mode(): + args.extend([str(bento.tag), "--opt", "progress=plain"]) + cli.commands["containerize"].main(standalone_mode=False, args=args) + return str(bento.tag) + + +@overload +def models_command( + ctx: click.Context, output: OutputLiteral, show_available: bool, machine: t.Literal[True] = True +) -> DictStrAny: + ... + + +@overload +def models_command( + ctx: click.Context, output: OutputLiteral, show_available: bool, machine: t.Literal[False] = ... +) -> None: + ... @cli.command() @@ -1467,8 +1872,11 @@ def build( default=False, help="Show available models in local store (mutually exclusive with '-o porcelain').", ) +@click.option("--machine", is_flag=True, default=False, hidden=True) @click.pass_context -def models(ctx: click.Context, output: OutputLiteral, show_available: bool): +def models_command( + ctx: click.Context, output: OutputLiteral, show_available: bool, machine: bool +) -> DictStrAny | None: """List all supported models. \b @@ -1498,7 +1906,7 @@ def models(ctx: click.Context, output: OutputLiteral, show_available: bool): converted: list[str] = [] for m in models: config = openllm.AutoConfig.for_model(m) - runtime_impl: tuple[t.Literal["pt", "flax", "tf"] | str, ...] = () + runtime_impl: tuple[str, ...] = () if config["model_name"] in openllm.MODEL_MAPPING_NAMES: runtime_impl += ("pt",) if config["model_name"] in openllm.MODEL_FLAX_MAPPING_NAMES: @@ -1527,7 +1935,13 @@ def models(ctx: click.Context, output: OutputLiteral, show_available: bool): ids_in_local_store = {k: [i for i in bentoml.models.list() if k in i.tag.name] for k in json_data.keys()} ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v} - if output == "pretty": + if machine: + dumped: DictStrAny = json_data + if show_available: + assert ids_in_local_store + dumped["local"] = [bentoml_cattr.unstructure(i.tag) for m in ids_in_local_store.values() for i in m] + return dumped + elif output == "pretty": import tabulate tabulate.PRESERVE_WHITESPACE = True @@ -1542,7 +1956,7 @@ def models(ctx: click.Context, output: OutputLiteral, show_available: bool): str, t.LiteralString, t.LiteralString, - tuple[t.Literal["pt", "flax", "tf"], ...], + tuple[LiteralRuntime, ...], ] ] = [] for m, v in json_data.items(): @@ -1619,7 +2033,7 @@ def models(ctx: click.Context, output: OutputLiteral, show_available: bool): ) _echo(formatted_table, fg="white") else: - dumped: dict[str, t.Any] = json_data + dumped: DictStrAny = json_data if show_available: assert ids_in_local_store dumped["local"] = [bentoml_cattr.unstructure(i.tag) for m in ids_in_local_store.values() for i in m] @@ -1642,7 +2056,7 @@ def models(ctx: click.Context, output: OutputLiteral, show_available: bool): help="Skip confirmation when deleting a specific model", ) @inject -def prune(yes: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store]): +def prune_command(yes: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store]): """Remove all saved models locally.""" available = [ m @@ -1671,8 +2085,6 @@ def parsing_instruction_callback( if isinstance(value, list): # we only parse --text foo bar -> --text foo and omit bar value = value[-1] - if not isinstance(value, str): - raise click.BadParameter(f"Invalid option format: {value}") key, *values = value.split("=") if not key.startswith("--"): @@ -1736,10 +2148,10 @@ def instruct( output: OutputLiteral, remote: bool, task: str, - _memoized: dict[str, t.Any], + _memoized: DictStrAny, **attrs: t.Any, ): - """Instruct agents interactively for given tasks, from a terminal + """Instruct agents interactively for given tasks, from a terminal. \b ```bash @@ -1804,11 +2216,6 @@ def query( input_fg = "yellow" generated_fg = "cyan" - model = t.cast( - "_BaseAutoLLMClass", - openllm[client.framework], # type: ignore (internal API) - ).for_model(client.model_name) - if output != "porcelain": _echo("Input prompt: ", nl=False, fg="white") _echo(f"{prompt}", fg="magenta", nl=False) @@ -1816,18 +2223,17 @@ def query( res = client.query(prompt, return_raw_response=True) if output == "pretty": - formatted = model.postprocess_generate(prompt, res["responses"]) - generated = formatted[len(prompt) :] + formatted = client.llm.postprocess_generate(prompt, res["responses"]) _echo("\n\n==Responses==\n", fg="white") - _echo(formatted[: len(prompt)], fg=input_fg, nl=False) - _echo(generated, fg=generated_fg) + _echo(f"{prompt} ", fg=input_fg, nl=False) + _echo(formatted, fg=generated_fg) elif output == "json": _echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white") else: _echo(res["responses"], fg="white") -def load_notebook_metadata() -> dict[str, t.Any]: +def load_notebook_metadata() -> DictStrAny: with open(os.path.join(os.path.dirname(openllm.playground.__file__), "_meta.yml"), "r") as f: content = yaml.safe_load(f) if not all("description" in k for k in content.values()): @@ -1846,7 +2252,7 @@ def load_notebook_metadata() -> dict[str, t.Any]: help="Default port for Jupyter server", ) def playground(output_dir: str | None, port: int): - """OpenLLM Playground + """OpenLLM Playground. A collections of notebooks to explore the capabilities of OpenLLM. This includes notebooks for fine-tuning, inference, and more. @@ -1889,18 +2295,29 @@ def playground(output_dir: str | None, port: int): continue _echo("Generating notebook for: " + module.name, fg="magenta") markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]["description"]) - f = jupytext.read(os.path.join(module.module_finder.path, module.name + ".py")) + f = jupytext.read(os.path.join(module.module_finder.path, module.name + ".py")) # type: ignore f.cells.insert(0, markdown_cell) jupytext.write(f, os.path.join(output_dir, module.name + ".ipynb"), fmt="notebook") try: subprocess.check_output( - ["jupyter", "notebook", "--notebook-dir", output_dir, "--port", str(port), "--no-browser", "--debug"] + [ + sys.executable, + "-m", + "jupyter", + "notebook", + "--notebook-dir", + output_dir, + "--port", + str(port), + "--no-browser", + "--debug", + ] ) except subprocess.CalledProcessError as e: _echo(e.output, fg="red") raise e except KeyboardInterrupt: - _echo("Shutting down Jupyter server...", fg="yellow") + _echo("\nShutting down Jupyter server...", fg="yellow") if _temp_dir: _echo("Note: You can access the generated notebooks in: " + output_dir, fg="blue") diff --git a/src/openllm/client.py b/src/openllm/client.py index 24f5ecca..f2fb141c 100644 --- a/src/openllm/client.py +++ b/src/openllm/client.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -OpenLLM client. +"""OpenLLM client. To start interact with the server, you can do the following: @@ -21,7 +20,6 @@ To start interact with the server, you can do the following: >>> client.query("What is the meaning of life?") """ from __future__ import annotations - import importlib import itertools import typing as t diff --git a/src/openllm/exceptions.py b/src/openllm/exceptions.py index c386c28d..59a99b99 100644 --- a/src/openllm/exceptions.py +++ b/src/openllm/exceptions.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Base exceptions for OpenLLM. This extends BentoML exceptions. -""" +"""Base exceptions for OpenLLM. This extends BentoML exceptions.""" from __future__ import annotations import bentoml diff --git a/src/openllm/models/auto/__init__.py b/src/openllm/models/auto/__init__.py index 86cbbcce..9bd50744 100644 --- a/src/openllm/models/auto/__init__.py +++ b/src/openllm/models/auto/__init__.py @@ -15,12 +15,14 @@ """This module is derived from HuggingFace's AutoConfig, AutoModel, etc.""" from __future__ import annotations - import typing as t import openllm -from ...utils import is_torch_available, is_flax_available, is_tf_available, LazyModule +from ...utils import LazyModule +from ...utils import is_flax_available +from ...utils import is_tf_available +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/auto/configuration_auto.py b/src/openllm/models/auto/configuration_auto.py index 1a6014fa..86a0c749 100644 --- a/src/openllm/models/auto/configuration_auto.py +++ b/src/openllm/models/auto/configuration_auto.py @@ -13,8 +13,6 @@ # limitations under the License. from __future__ import annotations - -import types import typing as t from collections import OrderedDict @@ -24,6 +22,7 @@ import openllm if t.TYPE_CHECKING: + import types from collections import _odict_items from collections import _odict_keys from collections import _odict_values @@ -93,9 +92,7 @@ class _LazyConfigMapping(ConfigOrderedDict): return item in self._mapping or item in self._extra_content def register(self, key: str, value: t.Any): - """ - Register a new configuration in this mapping. - """ + """Register a new configuration in this mapping.""" if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.") self._extra_content[key] = value @@ -115,7 +112,10 @@ CONFIG_NAME_ALIASES: dict[str, str] = { class AutoConfig: def __init__(self, *_: t.Any, **__: t.Any): - raise EnvironmentError("Cannot instantiate Config. Please use `Config.for_model(model_name)` instead.") + """This metaclass should be initialised via `AutoConfig.for_model`.""" + raise EnvironmentError( + "Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead." + ) @classmethod def for_model(cls, model_name: str, **attrs: t.Any) -> openllm.LLMConfig: diff --git a/src/openllm/models/auto/factory.py b/src/openllm/models/auto/factory.py index bf3bbbf4..f5291a8b 100644 --- a/src/openllm/models/auto/factory.py +++ b/src/openllm/models/auto/factory.py @@ -13,11 +13,10 @@ # limitations under the License. from __future__ import annotations - import importlib import inspect import logging -import types +import sys import typing as t from collections import OrderedDict @@ -30,13 +29,14 @@ from .configuration_auto import AutoConfig # NOTE: We need to do this so that overload can register # correct overloads to typing registry -if hasattr(t, "get_overloads"): +if sys.version_info[:2] >= (3, 11): from typing import overload else: from typing_extensions import overload if t.TYPE_CHECKING: + import types from collections import _odict_items from collections import _odict_keys from collections import _odict_values @@ -54,10 +54,10 @@ else: logger = logging.getLogger(__name__) -class _BaseAutoLLMClass: +class BaseAutoLLMClass: _model_mapping: _LazyAutoMapping - def __init__(self, *args: t.Any, **attrs: t.Any): + def __init__(self, *args: t.Any, **attrs: t.Any): # noqa raise EnvironmentError( f"Cannot instantiate {self.__class__.__name__} directly. " "Please use '{self.__class__.__name__}.Runner(model_name)' instead." @@ -129,7 +129,7 @@ class _BaseAutoLLMClass: llm = model_class.from_pretrained(model_id, model_version=model_version, llm_config=llm_config, **attrs) if ensure_available: logger.debug( - "'ensure_available=True', Downloading '%s' with 'model_id=%s' to local model store.", + "'ensure_available=True', OpenLLM will automatically import '%s' with 'model_id=%s' to local store if the entry does not exists.", model, llm.model_id, ) @@ -144,8 +144,7 @@ class _BaseAutoLLMClass: @classmethod def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner: - """ - Create a LLM Runner for the given model name. + """Create a LLM Runner for the given model name. Args: model: The model name to instantiate. @@ -160,8 +159,7 @@ class _BaseAutoLLMClass: @classmethod def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]): - """ - Register a new model for this class. + """Register a new model for this class. Args: config_class: The configuration corresponding to the model to register. @@ -191,13 +189,14 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any: try: return getattribute_from_module(openllm_module, attr) except ValueError: - raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") + raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None else: raise ValueError(f"Could not find {attr} in {openllm_module}!") class _LazyAutoMapping(ConfigModelOrderedDict): - """Based on transformers.models.auto.configuration_auto._LazyAutoMapping + """Based on transformers.models.auto.configuration_auto._LazyAutoMapping. + This OrderedDict values() and keys() returns the list instead, so you don't have to do list(mapping.values()) to get the list of values. """ @@ -281,9 +280,7 @@ class _LazyAutoMapping(ConfigModelOrderedDict): return model_type in self._model_mapping def register(self, key: t.Any, value: t.Any): - """ - Register a new model in this mapping. - """ + """Register a new model in this mapping.""" if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping: model_type = self._reverse_config_mapping[key.__name__] if model_type in self._model_mapping.keys(): @@ -292,4 +289,4 @@ class _LazyAutoMapping(ConfigModelOrderedDict): self._extra_content[key] = value -__all__ = ["_BaseAutoLLMClass", "_LazyAutoMapping"] +__all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"] diff --git a/src/openllm/models/auto/modeling_auto.py b/src/openllm/models/auto/modeling_auto.py index be477308..3dcfc402 100644 --- a/src/openllm/models/auto/modeling_auto.py +++ b/src/openllm/models/auto/modeling_auto.py @@ -13,12 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES -from .factory import _BaseAutoLLMClass +from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping @@ -45,5 +44,5 @@ MODEL_MAPPING: dict[ ] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) -class AutoLLM(_BaseAutoLLMClass): +class AutoLLM(BaseAutoLLMClass): _model_mapping = MODEL_MAPPING diff --git a/src/openllm/models/auto/modeling_flax_auto.py b/src/openllm/models/auto/modeling_flax_auto.py index cb134c8b..30d9b917 100644 --- a/src/openllm/models/auto/modeling_flax_auto.py +++ b/src/openllm/models/auto/modeling_flax_auto.py @@ -13,12 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES -from .factory import _BaseAutoLLMClass +from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping @@ -38,5 +37,5 @@ MODEL_FLAX_MAPPING: dict[ ] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES) -class AutoFlaxLLM(_BaseAutoLLMClass): +class AutoFlaxLLM(BaseAutoLLMClass): _model_mapping = MODEL_FLAX_MAPPING diff --git a/src/openllm/models/auto/modeling_tf_auto.py b/src/openllm/models/auto/modeling_tf_auto.py index cd86d207..99494b4d 100644 --- a/src/openllm/models/auto/modeling_tf_auto.py +++ b/src/openllm/models/auto/modeling_tf_auto.py @@ -13,12 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES -from .factory import _BaseAutoLLMClass +from .factory import BaseAutoLLMClass from .factory import _LazyAutoMapping @@ -38,5 +37,5 @@ MODEL_TF_MAPPING: dict[ ] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES) -class AutoTFLLM(_BaseAutoLLMClass): +class AutoTFLLM(BaseAutoLLMClass): _model_mapping = MODEL_TF_MAPPING diff --git a/src/openllm/models/chatglm/__init__.py b/src/openllm/models/chatglm/__init__.py index 193bd476..0d5513cc 100644 --- a/src/openllm/models/chatglm/__init__.py +++ b/src/openllm/models/chatglm/__init__.py @@ -13,11 +13,12 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, is_cpm_kernels_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_cpm_kernels_available +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/chatglm/configuration_chatglm.py b/src/openllm/models/chatglm/configuration_chatglm.py index 0ee7e054..82f0ab21 100644 --- a/src/openllm/models/chatglm/configuration_chatglm.py +++ b/src/openllm/models/chatglm/configuration_chatglm.py @@ -17,9 +17,7 @@ import openllm class ChatGLMConfig(openllm.LLMConfig): - """ - ChatGLM is an open bilingual language model based on - [General Language Model (GLM)](https://github.com/THUDM/GLM) framework. + """ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). diff --git a/src/openllm/models/chatglm/modeling_chatglm.py b/src/openllm/models/chatglm/modeling_chatglm.py index 678c6c6e..b74809ad 100644 --- a/src/openllm/models/chatglm/modeling_chatglm.py +++ b/src/openllm/models/chatglm/modeling_chatglm.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import typing as t import bentoml import openllm +from ...utils import generate_labels + if t.TYPE_CHECKING: import torch + import transformers else: torch = openllm.utils.LazyLoader("torch", globals(), "torch") @@ -34,15 +36,15 @@ class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrain self.device = torch.device("cuda") def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model: - (_, model_attrs), tokenizer_kwds = self.llm_parameters - attrs = {**model_attrs, **attrs} + _, tokenizer_attrs = self.llm_parameters return bentoml.transformers.save_model( self.tag, transformers.AutoModel.from_pretrained(self.model_id, trust_remote_code=trust_remote_code), + labels=generate_labels(self), custom_objects={ "tokenizer": transformers.AutoTokenizer.from_pretrained( - self.model_id, trust_remote_code=trust_remote_code, **tokenizer_kwds + self.model_id, trust_remote_code=trust_remote_code, **tokenizer_attrs ) }, ) @@ -62,8 +64,8 @@ class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrain if use_default_prompt_template and chat_history is not None: for i, (old_query, response) in enumerate(chat_history): - prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" - prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" + prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n" # noqa: RUF001 + prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:" # noqa: RUF001 else: prompt_text = prompt diff --git a/src/openllm/models/dolly_v2/__init__.py b/src/openllm/models/dolly_v2/__init__.py index 70454aec..ad8376d3 100644 --- a/src/openllm/models/dolly_v2/__init__.py +++ b/src/openllm/models/dolly_v2/__init__.py @@ -13,11 +13,12 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_torch_available + _import_structure = { "configuration_dolly_v2": ["DollyV2Config", "START_DOLLY_V2_COMMAND_DOCSTRING", "DEFAULT_PROMPT_TEMPLATE"], diff --git a/src/openllm/models/dolly_v2/configuration_dolly_v2.py b/src/openllm/models/dolly_v2/configuration_dolly_v2.py index 5425a79a..ad3e9529 100644 --- a/src/openllm/models/dolly_v2/configuration_dolly_v2.py +++ b/src/openllm/models/dolly_v2/configuration_dolly_v2.py @@ -11,12 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -The following includes OpenLLM configuration and excerpt from -[instruct_pipeline.py](https://huggingface.co/databricks/dolly-v2-3b/blob/main/instruct_pipeline.py) -""" from __future__ import annotations - import typing as t import openllm @@ -27,8 +22,7 @@ if t.TYPE_CHECKING: class DollyV2Config(openllm.LLMConfig): - """Databricks’ Dolly is an instruction-following large language model trained on the Databricks - machine learning platform that is licensed for commercial use. + """Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use. Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming, @@ -103,15 +97,19 @@ DEFAULT_PROMPT_TEMPLATE = """{intro} def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int: """Gets the token ID for a given string that has been added to the tokenizer as a special token. + When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to. + Args: - tokenizer (PreTrainedTokenizer): the tokenizer - key (str): the key to convert to a single token + tokenizer: the tokenizer + key: the key to convert to a single token + Raises: RuntimeError: if more than one ID was generated + Returns: - int: the token ID for the given key + int: the token ID for the given key. """ token_ids = tokenizer.encode(key) if len(token_ids) > 1: diff --git a/src/openllm/models/dolly_v2/modeling_dolly_v2.py b/src/openllm/models/dolly_v2/modeling_dolly_v2.py index 5b6a505d..c2b08358 100644 --- a/src/openllm/models/dolly_v2/modeling_dolly_v2.py +++ b/src/openllm/models/dolly_v2/modeling_dolly_v2.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import logging import re import typing as t -import bentoml import openllm -from ...utils import normalize_attrs_to_model_tokenizer_pair from .configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE from .configuration_dolly_v2 import END_KEY from .configuration_dolly_v2 import RESPONSE_KEY from .configuration_dolly_v2 import get_special_token_id +from ...utils import normalize_attrs_to_model_tokenizer_pair if t.TYPE_CHECKING: import tensorflow as tf import torch + + import bentoml import transformers else: tf = openllm.utils.LazyLoader("tf", globals(), "tensorflow") @@ -75,14 +75,16 @@ def get_pipeline( top_k: int = 0, **kwargs: t.Any, ): - """Initialize the pipeline + """Initialize the pipeline. + Args: - do_sample (bool, optional): Whether or not to use sampling. Defaults to True. - max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 128. - top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with - probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92. - top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering. - Defaults to 0. + do_sample: Whether or not to use sampling. Defaults to True. + max_new_tokens: Max new tokens after the prompt to generate. Defaults to 128. + top_p: If set to float < 1, only the smallest set of most probable tokens with + probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92. + top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to 0. + *args: Additional positional arguments to be passed to ``transformers.Pipeline``. + **kwargs: Additional keyword arguments to be passed to ``transformers.Pipeline``. """ super().__init__( *args, @@ -195,7 +197,7 @@ def get_pipeline( try: response_pos = sequence.index(response_key_token_id) except ValueError: - logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}") + logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence) response_pos = None if response_pos: @@ -228,7 +230,7 @@ def get_pipeline( if m: decoded = m.group(1).strip() else: - logger.warn(f"Failed to find response in:\n{fully_decoded}") + logger.warning("Failed to find response in:\n%s", fully_decoded) # If the full text is requested, then append the decoded text to the original instruction. # This technically isn't the full text, as we format the instruction in the prompt the model has been diff --git a/src/openllm/models/falcon/__init__.py b/src/openllm/models/falcon/__init__.py index ca11b5cf..7351dc47 100644 --- a/src/openllm/models/falcon/__init__.py +++ b/src/openllm/models/falcon/__init__.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/falcon/configuration_falcon.py b/src/openllm/models/falcon/configuration_falcon.py index 5e710572..2ebc6575 100644 --- a/src/openllm/models/falcon/configuration_falcon.py +++ b/src/openllm/models/falcon/configuration_falcon.py @@ -17,9 +17,9 @@ import openllm class FalconConfig(openllm.LLMConfig): - """Falcon-7B is a 7B parameters causal decoder-only model built by - TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) - enhanced with curated corpora. It is made available under the TII Falcon LLM License. + """Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora. + + It is made available under the TII Falcon LLM License. Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information. """ diff --git a/src/openllm/models/falcon/modeling_falcon.py b/src/openllm/models/falcon/modeling_falcon.py index 7d6145c4..6ea8c4ef 100644 --- a/src/openllm/models/falcon/modeling_falcon.py +++ b/src/openllm/models/falcon/modeling_falcon.py @@ -13,19 +13,19 @@ # limitations under the License. from __future__ import annotations - import typing as t -import bentoml import openllm -from ..._prompt import default_formatter from .configuration_falcon import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter if t.TYPE_CHECKING: import torch import torch.amp + + import bentoml import transformers else: torch = openllm.utils.LazyLoader("torch", globals(), "torch") @@ -81,7 +81,7 @@ class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTraine raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/flan_t5/__init__.py b/src/openllm/models/flan_t5/__init__.py index 39ce7517..119aa9b6 100644 --- a/src/openllm/models/flan_t5/__init__.py +++ b/src/openllm/models/flan_t5/__init__.py @@ -13,11 +13,13 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, is_tf_available, is_flax_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_flax_available +from ...utils import is_tf_available +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/flan_t5/configuration_flan_t5.py b/src/openllm/models/flan_t5/configuration_flan_t5.py index 815132ac..76b175bd 100644 --- a/src/openllm/models/flan_t5/configuration_flan_t5.py +++ b/src/openllm/models/flan_t5/configuration_flan_t5.py @@ -46,8 +46,9 @@ DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruct class FlanT5Config(openllm.LLMConfig): - """FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf) - - it is an enhanced version of T5 that has been finetuned in a mixture of tasks. + """FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf). + + It is an enhanced version of T5 that has been finetuned in a mixture of tasks. Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information. """ diff --git a/src/openllm/models/flan_t5/modeling_flan_t5.py b/src/openllm/models/flan_t5/modeling_flan_t5.py index 1bdaa0a6..f0341dfa 100644 --- a/src/openllm/models/flan_t5/modeling_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flan_t5.py @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import typing as t import openllm -from ..._prompt import default_formatter from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter if t.TYPE_CHECKING: import torch - import transformers # noqa + import transformers # noqa: F401 else: torch = openllm.utils.LazyLoader("torch", globals(), "torch") @@ -60,7 +59,7 @@ class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformer raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py index 829b8779..7e3eafd6 100644 --- a/src/openllm/models/flan_t5/modeling_flax_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_flax_flan_t5.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import typing as t import openllm -from ..._prompt import default_formatter from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter if t.TYPE_CHECKING: - import transformers # noqa + import transformers # noqa: F401 class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): @@ -54,7 +53,7 @@ class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "tra raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py index 66cb36fb..34f63082 100644 --- a/src/openllm/models/flan_t5/modeling_tf_flan_t5.py +++ b/src/openllm/models/flan_t5/modeling_tf_flan_t5.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import typing as t import openllm -from ..._prompt import default_formatter from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter if t.TYPE_CHECKING: - import transformers # noqa + import transformers # noqa: F401 class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]): @@ -40,17 +39,20 @@ class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transfo **attrs: t.Any, ) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]: if use_default_prompt_template: - prompt_variables = { - k: v - for k, v in attrs.items() - if k in default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) - } + template_variables = default_formatter.extract_template_variables(DEFAULT_PROMPT_TEMPLATE) + prompt_variables = {k: v for k, v in attrs.items() if k in template_variables} if "instruction" in prompt_variables: raise RuntimeError( "'instruction' should be passed as the first argument " "instead of kwargs when 'use_default_prompt_template=True'" ) - prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + try: + prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=prompt, **prompt_variables) + except KeyError as e: + raise RuntimeError( + f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " + "Use 'use_default_prompt_template=False' to disable the default prompt template." + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/gpt_neox/__init__.py b/src/openllm/models/gpt_neox/__init__.py index b285f344..eac4d399 100644 --- a/src/openllm/models/gpt_neox/__init__.py +++ b/src/openllm/models/gpt_neox/__init__.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/gpt_neox/configuration_gpt_neox.py b/src/openllm/models/gpt_neox/configuration_gpt_neox.py index 7eb077d9..a16e55b5 100644 --- a/src/openllm/models/gpt_neox/configuration_gpt_neox.py +++ b/src/openllm/models/gpt_neox/configuration_gpt_neox.py @@ -14,14 +14,13 @@ from __future__ import annotations -from __future__ import annotations - import openllm class GPTNeoXConfig(openllm.LLMConfig): - """GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and - openly available to the public through a permissive license. It is, to the best of our knowledge, the largest dense autoregressive model + """GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license. + + It is, to the best of our knowledge, the largest dense autoregressive model that has publicly available weights at the time of submission. The training and evaluation code, as well as the model weights, can be found at https://github.com/EleutherAI/gpt-neox. diff --git a/src/openllm/models/gpt_neox/modeling_gpt_neox.py b/src/openllm/models/gpt_neox/modeling_gpt_neox.py index 0d26be6b..911e99d1 100644 --- a/src/openllm/models/gpt_neox/modeling_gpt_neox.py +++ b/src/openllm/models/gpt_neox/modeling_gpt_neox.py @@ -13,21 +13,21 @@ # limitations under the License. from __future__ import annotations - import logging import typing as t import openllm -from ..._prompt import default_formatter from .configuration_gpt_neox import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter if t.TYPE_CHECKING: - import bentoml - import transformers # noqa import torch import torch.amp + + import bentoml + import transformers else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") torch = openllm.utils.LazyLoader("torch", globals(), "torch") @@ -62,7 +62,7 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/mpt/__init__.py b/src/openllm/models/mpt/__init__.py index 2de02b2c..19010afb 100644 --- a/src/openllm/models/mpt/__init__.py +++ b/src/openllm/models/mpt/__init__.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/mpt/configuration_mpt.py b/src/openllm/models/mpt/configuration_mpt.py index de20c19a..fffc2047 100644 --- a/src/openllm/models/mpt/configuration_mpt.py +++ b/src/openllm/models/mpt/configuration_mpt.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t import openllm @@ -27,10 +26,11 @@ else: class MPTConfig(openllm.LLMConfig): - """MPT is a decoder-style transformer pretrained from scratch on - English text and code. This model was trained by [MosaicML](https://www.mosaicml.com/). + """MPT is a decoder-style transformer pretrained from scratch on English text and code. - `openllm.MPT` encapsulate a family of MPT variants that is publicly available + This model was trained by [MosaicML](https://www.mosaicml.com/). + + ``openllm.MPT`` encapsulate a family of MPT variants that is publicly available on HuggingFace. Refers [HuggingFace's MosaicML page](https://huggingface.co/mosaicml) for more details on specific models. """ diff --git a/src/openllm/models/mpt/modeling_mpt.py b/src/openllm/models/mpt/modeling_mpt.py index 45618e4c..28acec5f 100644 --- a/src/openllm/models/mpt/modeling_mpt.py +++ b/src/openllm/models/mpt/modeling_mpt.py @@ -13,16 +13,16 @@ # limitations under the License. from __future__ import annotations - import logging import typing as t import bentoml import openllm -from ..._prompt import default_formatter -from ...utils import is_triton_available from .configuration_mpt import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter +from ...utils import generate_labels +from ...utils import is_triton_available if t.TYPE_CHECKING: @@ -78,8 +78,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken return model_kwds, tokenizer_kwds def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model: - (_, model_attrs), tokenizer_kwds = self.llm_parameters - attrs = {**model_attrs, **attrs} + _, tokenizer_attrs = self.llm_parameters torch_dtype = attrs.pop("torch_dtype", self.dtype) device_map = attrs.pop("device_map", None) @@ -93,7 +92,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken trust_remote_code=trust_remote_code, ) - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds) + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs) if tokenizer.pad_token_id is None: logger.warning("pad_token_id is not set. Setting it to eos_token") tokenizer.pad_token = tokenizer.eos_token @@ -107,7 +106,12 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken **attrs, ) try: - return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}) + return bentoml.transformers.save_model( + self.tag, + model, + custom_objects={"tokenizer": tokenizer}, + labels=generate_labels(self), + ) finally: torch.cuda.empty_cache() @@ -169,7 +173,7 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/opt/__init__.py b/src/openllm/models/opt/__init__.py index 032aae60..c55356e1 100644 --- a/src/openllm/models/opt/__init__.py +++ b/src/openllm/models/opt/__init__.py @@ -13,11 +13,13 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, LazyModule, is_flax_available, is_tf_available from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_flax_available +from ...utils import is_tf_available +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/opt/configuration_opt.py b/src/openllm/models/opt/configuration_opt.py index dbb44894..0a8facac 100644 --- a/src/openllm/models/opt/configuration_opt.py +++ b/src/openllm/models/opt/configuration_opt.py @@ -18,9 +18,7 @@ import openllm class OPTConfig(openllm.LLMConfig): - """OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) - and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) - on May 3rd 2022 by Meta AI. + """OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI. OPT was predominantly pretrained with English text, but a small amount of non-English data is still present within the training corpus via CommonCrawl. The model was pretrained using a causal language modeling (CLM) diff --git a/src/openllm/models/opt/modeling_flax_opt.py b/src/openllm/models/opt/modeling_flax_opt.py index 533c3e4d..00e1a0f1 100644 --- a/src/openllm/models/opt/modeling_flax_opt.py +++ b/src/openllm/models/opt/modeling_flax_opt.py @@ -13,15 +13,15 @@ # limitations under the License. from __future__ import annotations - import logging import typing as t import bentoml import openllm -from ..._prompt import default_formatter from .configuration_opt import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter +from ...utils import generate_labels if t.TYPE_CHECKING: @@ -44,17 +44,21 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok return {}, tokenizer_kwds def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: - (_, model_attrs), tokenizer_kwds = self.llm_parameters - attrs = {**model_attrs, **attrs} + _, tokenizer_attrs = self.llm_parameters config = transformers.AutoConfig.from_pretrained(self.model_id) - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds) + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs) tokenizer.pad_token_id = config.pad_token_id model = t.cast( "transformers.FlaxOPTForCausalLM", transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), ) - return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}) + return bentoml.transformers.save_model( + self.tag, + model, + custom_objects={"tokenizer": tokenizer}, + labels=generate_labels(self), + ) def sanitize_parameters( self, @@ -81,7 +85,7 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/opt/modeling_opt.py b/src/openllm/models/opt/modeling_opt.py index ac17b3c7..c9e00bb6 100644 --- a/src/openllm/models/opt/modeling_opt.py +++ b/src/openllm/models/opt/modeling_opt.py @@ -13,21 +13,21 @@ # limitations under the License. from __future__ import annotations - import logging import typing as t import bentoml import openllm -from ..._prompt import default_formatter from .configuration_opt import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter +from ...utils import generate_labels if t.TYPE_CHECKING: import torch - import transformers # noqa + import transformers else: torch = openllm.utils.LazyLoader("torch", globals(), "torch") transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") @@ -55,13 +55,12 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer return model_kwds, tokenizer_kwds def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: - (_, model_attrs), tokenizer_kwds = self.llm_parameters - attrs = {**model_attrs, **attrs} + _, tokenizer_attrs = self.llm_parameters torch_dtype = attrs.pop("torch_dtype", self.dtype) config = transformers.AutoConfig.from_pretrained(self.model_id) - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds) + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs) tokenizer.pad_token_id = config.pad_token_id model = t.cast( "transformers.OPTForCausalLM", @@ -69,7 +68,12 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer self.model_id, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, **attrs ), ) - return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}) + return bentoml.transformers.save_model( + self.tag, + model, + custom_objects={"tokenizer": tokenizer}, + labels=generate_labels(self), + ) def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> transformers.OPTForCausalLM: torch_dtype = attrs.pop("torch_dtype", self.dtype) @@ -105,7 +109,7 @@ class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/opt/modeling_tf_opt.py b/src/openllm/models/opt/modeling_tf_opt.py index aa9a7e0f..cada4e3d 100644 --- a/src/openllm/models/opt/modeling_tf_opt.py +++ b/src/openllm/models/opt/modeling_tf_opt.py @@ -13,15 +13,15 @@ # limitations under the License. from __future__ import annotations - import logging import typing as t import bentoml import openllm -from ..._prompt import default_formatter from .configuration_opt import DEFAULT_PROMPT_TEMPLATE +from ..._prompt import default_formatter +from ...utils import generate_labels if t.TYPE_CHECKING: @@ -45,16 +45,20 @@ class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Token return {}, tokenizer_kwds def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: - (_, model_attrs), tokenizer_kwds = self.llm_parameters - attrs = {**model_attrs, **attrs} + _, tokenizer_attrs = self.llm_parameters config = transformers.AutoConfig.from_pretrained(self.model_id) - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds) + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs) tokenizer.pad_token_id = config.pad_token_id model: transformers.TFOPTForCausalLM = transformers.TFOPTForCausalLM.from_pretrained( self.model_id, trust_remote_code=trust_remote_code, **attrs ) - return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}) + return bentoml.transformers.save_model( + self.tag, + model, + custom_objects={"tokenizer": tokenizer}, + labels=generate_labels(self), + ) def sanitize_parameters( self, @@ -80,7 +84,7 @@ class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Token raise RuntimeError( f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. " "Use 'use_default_prompt_template=False' to disable the default prompt template." - ) + ) from None else: prompt_text = prompt diff --git a/src/openllm/models/stablelm/__init__.py b/src/openllm/models/stablelm/__init__.py index b99e66e7..358146e4 100644 --- a/src/openllm/models/stablelm/__init__.py +++ b/src/openllm/models/stablelm/__init__.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/stablelm/configuration_stablelm.py b/src/openllm/models/stablelm/configuration_stablelm.py index aafde6db..696d42f8 100644 --- a/src/openllm/models/stablelm/configuration_stablelm.py +++ b/src/openllm/models/stablelm/configuration_stablelm.py @@ -17,8 +17,9 @@ import openllm class StableLMConfig(openllm.LLMConfig): - """StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models - pre-trained on a diverse collection of English datasets with a sequence + """StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models. + + It is pre-trained on a diverse collection of English datasets with a sequence length of 4096 to push beyond the context window limitations of existing open-source language models. StableLM-Tuned-Alpha is a suite of 3B and 7B parameter decoder-only language models @@ -74,6 +75,6 @@ SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version) - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. - StableLM will refuse to participate in anything that could harm a human. -""" # noqa +""" DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>""" diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py index ec9de35d..5dc2051c 100644 --- a/src/openllm/models/stablelm/modeling_stablelm.py +++ b/src/openllm/models/stablelm/modeling_stablelm.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import logging import typing as t import openllm -from ..._prompt import default_formatter from .configuration_stablelm import DEFAULT_PROMPT_TEMPLATE from .configuration_stablelm import SYSTEM_PROMPT +from ..._prompt import default_formatter if t.TYPE_CHECKING: diff --git a/src/openllm/models/starcoder/__init__.py b/src/openllm/models/starcoder/__init__.py index 5bdc2da0..3954349b 100644 --- a/src/openllm/models/starcoder/__init__.py +++ b/src/openllm/models/starcoder/__init__.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations - import typing as t -from ...utils import is_torch_available, LazyModule from ...exceptions import MissingDependencyError +from ...utils import LazyModule +from ...utils import is_torch_available _import_structure = { diff --git a/src/openllm/models/starcoder/configuration_starcoder.py b/src/openllm/models/starcoder/configuration_starcoder.py index 807952cf..bb40d22c 100644 --- a/src/openllm/models/starcoder/configuration_starcoder.py +++ b/src/openllm/models/starcoder/configuration_starcoder.py @@ -17,8 +17,7 @@ import openllm class StarCoderConfig(openllm.LLMConfig): - """The StarCoder models are 15.5B parameter models trained on 80+ programming languages from - [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded. + """The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded. The model uses [Multi Query Attention](https://arxiv.org/abs/1911.02150), [a context window of 8192 tokens](https://arxiv.org/abs/2205.14135), and was trained using the diff --git a/src/openllm/models/starcoder/modeling_starcoder.py b/src/openllm/models/starcoder/modeling_starcoder.py index b10d2898..7f3a3e85 100644 --- a/src/openllm/models/starcoder/modeling_starcoder.py +++ b/src/openllm/models/starcoder/modeling_starcoder.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import logging import typing as t import bentoml import openllm +from ...utils import generate_labels + if t.TYPE_CHECKING: import torch @@ -54,13 +55,12 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers. return model_kwds, tokenizer_kwds def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model: - (_, model_attrs), tokenizer_kwds = self.llm_parameters - attrs = {**model_attrs, **attrs} + _, tokenizer_attrs = self.llm_parameters torch_dtype = attrs.pop("torch_dtype", torch.float16) device_map = attrs.pop("device_map", "auto") - tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_kwds) + tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs) tokenizer.add_special_tokens( { "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], @@ -72,7 +72,12 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers. self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs ) try: - return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}) + return bentoml.transformers.save_model( + self.tag, + model, + custom_objects={"tokenizer": tokenizer}, + labels=generate_labels(self), + ) finally: # NOTE: We need to free the cache after saving here so that we can load it back later on. torch.cuda.empty_cache() diff --git a/src/openllm/playground/falcon_tuned.py b/src/openllm/playground/falcon_tuned.py index 18cef46f..8d0030d3 100644 --- a/src/openllm/playground/falcon_tuned.py +++ b/src/openllm/playground/falcon_tuned.py @@ -1,5 +1,4 @@ from __future__ import annotations - import dataclasses import logging import os diff --git a/src/openllm/playground/features.py b/src/openllm/playground/features.py index 59322397..c47075e0 100644 --- a/src/openllm/playground/features.py +++ b/src/openllm/playground/features.py @@ -1,5 +1,4 @@ from __future__ import annotations - import argparse import logging import typing as t diff --git a/src/openllm/playground/opt_tuned.py b/src/openllm/playground/opt_tuned.py index a5273b33..020220af 100644 --- a/src/openllm/playground/opt_tuned.py +++ b/src/openllm/playground/opt_tuned.py @@ -1,5 +1,4 @@ from __future__ import annotations - import dataclasses import logging import os diff --git a/src/openllm/serialisation/__init__.py b/src/openllm/serialisation/__init__.py index 8be0ca97..87f70f2a 100644 --- a/src/openllm/serialisation/__init__.py +++ b/src/openllm/serialisation/__init__.py @@ -37,15 +37,20 @@ llm.save_pretrained("./path/to/local-dolly") """ from __future__ import annotations +import typing as t + +import openllm from ..utils import LazyModule -import typing as t -import openllm + if t.TYPE_CHECKING: import bentoml - from .._types import ModelProtocol, TokenizerProtocol - from .transformers import _M, _T + + from .._llm import M + from .._llm import T + from .._types import ModelProtocol + from .._types import TokenizerProtocol def import_model( @@ -80,7 +85,7 @@ def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs raise ValueError(f"Unknown runtime: {llm.config['runtime']}") -def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]: +def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]: if llm.runtime == "transformers": return openllm.transformers.load_model(llm, *decls, **attrs) elif llm.runtime == "ggml": @@ -89,7 +94,7 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo raise ValueError(f"Unknown runtime: {llm.config['runtime']}") -def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]: +def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]: if llm.runtime == "transformers": return openllm.transformers.load_tokenizer(llm) elif llm.runtime == "ggml": @@ -109,11 +114,6 @@ _extras = { _import_structure: dict[str, list[str]] = {"ggml": [], "transformers": []} if t.TYPE_CHECKING: - from . import import_model as import_model - from . import get as get - from . import save_pretrained as save_pretrained - from . import load_model as load_model - from . import load_tokenizer as load_tokenizer from . import ggml as ggml from . import transformers as transformers else: diff --git a/src/openllm/serialisation/ggml.py b/src/openllm/serialisation/ggml.py index ae79b8f1..430a94f7 100644 --- a/src/openllm/serialisation/ggml.py +++ b/src/openllm/serialisation/ggml.py @@ -16,19 +16,25 @@ This requires ctransformers to be installed. """ from __future__ import annotations - -import openllm import typing as t -import bentoml + import cloudpickle -from ..exceptions import OpenLLMException -from ..utils import LazyLoader + +import bentoml from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME +from ..exceptions import OpenLLMException +from ..utils import LazyLoader + + if t.TYPE_CHECKING: - from .._types import ModelProtocol, TokenizerProtocol - from .transformers import _M, _T + import openllm import transformers + + from .._llm import M + from .._llm import T + from .._types import ModelProtocol + from .._types import TokenizerProtocol else: transformers = LazyLoader("transformers", globals(), "transformers") @@ -44,6 +50,7 @@ def import_model( def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model: """Return an instance of ``bentoml.Model`` from given LLM instance. + By default, it will try to check the model in the local store. If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub. @@ -66,15 +73,16 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo raise -def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]: +def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]: """Load the model from BentoML store. + By default, it will try to find check the model in the local store. If model is not found, it will raises a ``bentoml.exceptions.NotFound``. """ raise NotImplementedError("Currently work in progress.") -def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]: +def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]: """Load the tokenizer from BentoML store. By default, it will try to find the bentomodel whether it is in store.. @@ -95,14 +103,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]: "Model does not have tokenizer. Make sure to save \ the tokenizer within the model via 'custom_objects'.\ For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))" - ) + ) from None else: tokenizer = transformers.AutoTokenizer.from_pretrained( bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs, ) - return t.cast("TokenizerProtocol[_T]", tokenizer) + return t.cast("TokenizerProtocol[T]", tokenizer) def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any): diff --git a/src/openllm/serialisation/transformers.py b/src/openllm/serialisation/transformers.py index cc2a1b1c..5d48c80d 100644 --- a/src/openllm/serialisation/transformers.py +++ b/src/openllm/serialisation/transformers.py @@ -11,32 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Serialisation related implementation for Transformers-based implementation. -""" +"""Serialisation related implementation for Transformers-based implementation.""" from __future__ import annotations - import copy -import openllm -import typing as t import importlib +import typing as t + +import cloudpickle + import bentoml from bentoml._internal.frameworks.transformers import make_default_signatures -from bentoml._internal.models.model import ModelOptions -from ..exceptions import OpenLLMException -import cloudpickle from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME -from ..utils import LazyLoader, is_torch_available -from ..utils import generate_context, normalize_attrs_to_model_tokenizer_pair -from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING, MODEL_TO_AUTOCLASS_MAPPING +from bentoml._internal.models.model import ModelOptions + +from .constants import FRAMEWORK_TO_AUTOCLASS_MAPPING +from .constants import MODEL_TO_AUTOCLASS_MAPPING +from ..exceptions import OpenLLMException +from ..utils import LazyLoader +from ..utils import first_not_none +from ..utils import generate_context +from ..utils import generate_labels +from ..utils import is_torch_available +from ..utils import normalize_attrs_to_model_tokenizer_pair + if t.TYPE_CHECKING: - import transformers import torch - from .._types import P - from .._llm import _M, _T - from .._types import DictStrAny, ModelProtocol, TokenizerProtocol + + import openllm + import transformers from transformers.models.auto.auto_factory import _BaseAutoModelClass + + from .._llm import M + from .._llm import T + from .._types import DictStrAny + from .._types import ModelProtocol + from .._types import TokenizerProtocol else: transformers = LazyLoader("transformers", globals(), "transformers") torch = LazyLoader("torch", globals(), "torch") @@ -46,7 +57,6 @@ def process_transformers_config( model_id: str, trust_remote_code: bool, **attrs: t.Any ) -> tuple[transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]: """Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs.""" - config: transformers.PretrainedConfig = attrs.pop("config", None) # this logic below is synonymous to handling `from_pretrained` attrs. @@ -123,7 +133,7 @@ def import_model( attrs = {**model_attrs, **attrs} tokenizer = t.cast( - transformers.PreTrainedTokenizer, + "transformers.PreTrainedTokenizer", transformers.AutoTokenizer.from_pretrained( llm.model_id, config=config, @@ -133,13 +143,16 @@ def import_model( ), ) - model = infer_autoclass_from_llm_config(llm, config).from_pretrained( - llm.model_id, - *decls, - config=config, - trust_remote_code=trust_remote_code, - **hub_attrs, - **attrs, + model = t.cast( + "transformers.PreTrainedModel", + infer_autoclass_from_llm_config(llm, config).from_pretrained( + llm.model_id, + *decls, + config=config, + trust_remote_code=trust_remote_code, + **hub_attrs, + **attrs, + ), ) try: @@ -148,7 +161,7 @@ def import_model( module="openllm.serialisation.transformers", api_version="v1", context=generate_context(framework_name="openllm"), - labels={"runtime": llm.runtime}, + labels=generate_labels(llm), options=ModelOptions(), signatures=make_default_signatures(model), external_modules=[ @@ -173,6 +186,7 @@ def import_model( def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model: """Return an instance of ``bentoml.Model`` from given LLM instance. + By default, it will try to check the model in the local store. If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub. @@ -201,8 +215,9 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo raise -def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[_M]: +def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> ModelProtocol[M]: """Load the model from BentoML store. + By default, it will try to find check the model in the local store. If model is not found, it will raises a ``bentoml.exceptions.NotFound``. """ @@ -224,11 +239,11 @@ def load_model(llm: openllm.LLM[_M, t.Any], *decls: t.Any, **attrs: t.Any) -> Mo # BetterTransformer is currently only supported on PyTorch. from optimum.bettertransformer import BetterTransformer - model = BetterTransformer.transform(model) - return t.cast("ModelProtocol[_M]", model) + model = BetterTransformer.transform(model) # type: ignore + return t.cast("ModelProtocol[M]", model) -def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]: +def load_tokenizer(llm: openllm.LLM[t.Any, T]) -> TokenizerProtocol[T]: """Load the tokenizer from BentoML store. By default, it will try to find the bentomodel whether it is in store.. @@ -249,14 +264,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, _T]) -> TokenizerProtocol[_T]: "Model does not have tokenizer. Make sure to save \ the tokenizer within the model via 'custom_objects'.\ For example: bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer}))" - ) + ) from None else: tokenizer = transformers.AutoTokenizer.from_pretrained( bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs, ) - return t.cast("TokenizerProtocol[_T]", tokenizer) + return tokenizer def save_pretrained( @@ -264,7 +279,7 @@ def save_pretrained( save_directory: str, is_main_process: bool = True, state_dict: DictStrAny | None = None, - save_function: t.Callable[P, None] | None = None, + save_function: t.Callable[..., None] | None = None, push_to_hub: bool = False, max_shard_size: int | str = "10GB", safe_serialization: bool = False, @@ -272,8 +287,7 @@ def save_pretrained( **attrs: t.Any, ): """Light wrapper around ``transformers.PreTrainedTokenizer.save_pretrained`` and ``transformers.PreTrainedModel.save_pretrained``.""" - if save_function is None: - save_function = torch.save + save_function = first_not_none(save_function, default=torch.save) model_save_attrs, tokenizer_save_attrs = normalize_attrs_to_model_tokenizer_pair(**attrs) diff --git a/src/openllm/testing.py b/src/openllm/testing.py new file mode 100644 index 00000000..44d69334 --- /dev/null +++ b/src/openllm/testing.py @@ -0,0 +1,113 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests utilities for OpenLLM.""" + +from __future__ import annotations +import contextlib +import logging +import shutil +import subprocess +import typing as t + +import bentoml +import openllm + + +logger = logging.getLogger(__name__) + +if t.TYPE_CHECKING: + from ._types import LiteralRuntime + + +@contextlib.contextmanager +def build_bento( + model: str, + model_id: str | None = None, + quantize: t.Literal["int4", "int8", "gptq"] | None = None, + runtime: t.Literal["ggml", "transformers"] = "transformers", + cleanup: bool = False, +): + logger.info("Building BentoML for %s", model) + bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime) + yield bento + if cleanup: + logger.info("Deleting %s", bento.tag) + bentoml.bentos.delete(bento.tag) + + +@contextlib.contextmanager +def build_container( + bento: bentoml.Bento | str | bentoml.Tag, + image_tag: str | None = None, + cleanup: bool = False, + **attrs: t.Any, +): + if isinstance(bento, bentoml.Bento): + bento_tag = bento.tag + else: + bento_tag = bentoml.Tag.from_taglike(bento) + + if image_tag is None: + image_tag = str(bento_tag) + + executable = shutil.which("docker") + if not executable: + raise RuntimeError("docker executable not found") + + try: + logger.info("Building container for %s", bento_tag) + bentoml.container.build( + bento_tag, + backend="docker", + image_tag=(image_tag,), + progress="plain", + **attrs, + ) + yield image_tag + finally: + if cleanup: + logger.info("Deleting container %s", image_tag) + subprocess.check_output([executable, "rmi", "-f", image_tag]) + + +@contextlib.contextmanager +def prepare( + model: str, + model_id: str | None = None, + implementation: LiteralRuntime = "pt", + deployment_mode: t.Literal["container", "local"] = "local", + clean_context: contextlib.ExitStack | None = None, + cleanup: bool = False, +): + if clean_context is None: + clean_context = contextlib.ExitStack() + cleanup = True + + llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True) + bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}") + + if not bentoml.list(bento_tag): + bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup)) + else: + bento = bentoml.get(bento_tag) + + container_name = f"openllm-{model}-{llm.llm_type}".replace("-", "_") + + if deployment_mode == "container": + container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup)) + + yield container_name + if cleanup: + clean_context.close() diff --git a/src/openllm/tests.py b/src/openllm/tests.py deleted file mode 100644 index 201d6804..00000000 --- a/src/openllm/tests.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import typing as t - -from .utils import is_flax_available -from .utils import is_tf_available -from .utils import is_torch_available - - -try: - import pytest -except ImportError: - raise ImportError("You need to install pytest to use 'openllm.tests' utilities: 'pip install pytest'") - - -def require_tf(f: t.Callable[..., t.Any]): - return pytest.mark.skipif(not is_tf_available(), reason="requires TensorFlow")(f) - - -def require_flax(f: t.Callable[..., t.Any]): - return pytest.mark.skipif(not is_flax_available(), reason="requires Flax")(f) - - -def require_torch(f: t.Callable[..., t.Any]): - return pytest.mark.skipif(not is_torch_available(), reason="requires PyTorch")(f) diff --git a/src/openllm/utils/__init__.py b/src/openllm/utils/__init__.py index c4b05239..2a1f3a45 100644 --- a/src/openllm/utils/__init__.py +++ b/src/openllm/utils/__init__.py @@ -11,30 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Utilities function for OpenLLM. User can import these function for convenience, but +"""Utilities function for OpenLLM. + +User can import these function for convenience, but we won't ensure backward compatibility for these functions. So use with caution. """ -from __future__ import annotations as _annotations - +from __future__ import annotations +import contextlib import functools import logging import logging.config import os import sys -import platform import types import typing as t +from pathlib import Path +from circus.exc import ConflictError + +from bentoml._internal.configuration import DEBUG_ENV_VAR as _DEBUG_ENV_VAR +from bentoml._internal.configuration import GRPC_DEBUG_ENV_VAR as _GRPC_DEBUG_ENV_VAR from bentoml._internal.configuration import get_debug_mode from bentoml._internal.configuration import get_quiet_mode -from bentoml._internal.configuration import set_debug_mode from bentoml._internal.configuration import set_quiet_mode +from bentoml._internal.log import configure_server_logging from bentoml._internal.models.model import ModelContext as _ModelContext -from bentoml._internal.log import CLI_LOGGING_CONFIG as _CLI_LOGGING_CONFIG from bentoml._internal.types import LazyType from bentoml._internal.utils import LazyLoader from bentoml._internal.utils import bentoml_cattr +from bentoml._internal.utils import cached_contextmanager from bentoml._internal.utils import copy_file_to_fs_folder from bentoml._internal.utils import first_not_none from bentoml._internal.utils import pkg @@ -62,8 +67,26 @@ else: types.UnionType, ) +# NOTE: We need to do this so that overload can register +# correct overloads to typing registry +if sys.version_info[:2] >= (3, 11): + from typing import overload as _overload +else: + from typing_extensions import overload as _overload + if t.TYPE_CHECKING: + import openllm + from .._types import DictStrAny + from .._types import LiteralRuntime + from .._types import P + from ..models.auto.factory import BaseAutoLLMClass + + +def set_debug_mode(enabled: bool): + # monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs + os.environ[_DEBUG_ENV_VAR] = str(enabled) + os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR" def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool: @@ -75,16 +98,12 @@ def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.An raise -def gpu_count() -> tuple[int, ...]: +def gpu_count() -> tuple[str, ...]: from bentoml._internal.resource import NvidiaGpuResource cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) if cuda_visible_devices is not None: - if "," in cuda_visible_devices: - available_gpu = tuple(int(i) for i in cuda_visible_devices.split(",")) - else: - available_gpu = tuple(int(i) for i in cuda_visible_devices.split()) - return available_gpu + return tuple(i for i in cuda_visible_devices.split(",")) return tuple(NvidiaGpuResource.from_system()) @@ -94,7 +113,7 @@ _object_setattr = object.__setattr__ def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None: - """This makes sure that we don't overwrite any existing attributes on the object""" + """This makes sure that we don't overwrite any existing attributes on the object.""" _setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj) if not hasattr(obj, name): @@ -107,22 +126,65 @@ def field_env_key(model_name: str, key: str, suffix: str | t.Literal[""] | None DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get("OPENLLMDEVDEBUG"))) +SHOW_CODEGEN = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3 -_LOGGING_CONFIG = _CLI_LOGGING_CONFIG.copy() -_LOGGING_CONFIG["loggers"].update( - { - "openllm": { - "level": logging.INFO, + +class _ExceptionFilter(logging.Filter): + def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any): + if exclude_exceptions is None: + exclude_exceptions = [ConflictError] + else: + exclude_exceptions.append(ConflictError) + + super(_ExceptionFilter, self).__init__(**kwargs) + self.EXCLUDE_EXCEPTIONS = exclude_exceptions + + def filter(self, record: logging.LogRecord) -> bool: + if record.exc_info: + etype, _, _ = record.exc_info + if etype is not None: + for exc in self.EXCLUDE_EXCEPTIONS: + if issubclass(etype, exc): + return False + return True + + +_LOGGING_CONFIG: DictStrAny = { + "version": 1, + "disable_existing_loggers": True, + "filters": {"excfilter": {"()": _ExceptionFilter}}, + "handlers": { + "bentomlhandler": { + "class": "logging.StreamHandler", + "filters": ["excfilter"], + "stream": "ext://sys.stdout", + }, + "defaulthandler": { + "class": "logging.StreamHandler", + "level": logging.WARNING, + }, + }, + "loggers": { + "bentoml": { "handlers": ["bentomlhandler", "defaulthandler"], + "level": logging.INFO, "propagate": False, - } - } -) + }, + "openllm": { + "handlers": ["bentomlhandler", "defaulthandler"], + "level": logging.INFO, + "propagate": False, + }, + }, + "root": {"level": logging.WARNING}, +} def configure_logging() -> None: - """Configure logging for OpenLLM. Behaves similar to how BentoML loggers - are being configured.""" + """Configure logging for OpenLLM. + + Behaves similar to how BentoML loggers are being configured. + """ if get_quiet_mode(): _LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR _LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR @@ -144,7 +206,7 @@ def in_notebook() -> bool: try: from IPython.core.getipython import get_ipython - if "IPKernelApp" not in get_ipython().config: # pragma: no cover + if "IPKernelApp" not in get_ipython().config: # type: ignore return False except ImportError: return False @@ -153,8 +215,91 @@ def in_notebook() -> bool: return True +_dockerenv = Path("/.dockerenv") +_cgroup = Path("/proc/self/cgroup") + + +class suppress(contextlib.suppress, contextlib.ContextDecorator): + """A version of contextlib.suppress with decorator support. + + >>> @suppress(KeyError) + ... def key_error(): + ... {}[''] + >>> key_error() + """ + + +def compose(*funcs: t.Callable[..., t.Any]): + """Compose any number of unary functions into a single unary function. + + >>> import textwrap + >>> expected = str.strip(textwrap.dedent(compose.__doc__)) + >>> strip_and_dedent = compose(str.strip, textwrap.dedent) + >>> strip_and_dedent(compose.__doc__) == expected + True + + Compose also allows the innermost function to take arbitrary arguments. + + >>> round_three = lambda x: round(x, ndigits=3) + >>> f = compose(round_three, int.__truediv__) + >>> [f(3*x, x+1) for x in range(1,10)] + [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] + """ + + def compose_two(f1: t.Callable[..., t.Any], f2: t.Callable[P, t.Any]): + def _(*args: P.args, **kwargs: P.kwargs) -> t.Any: + return f1(f2(*args, **kwargs)) + + return _ + + return functools.reduce(compose_two, funcs) + + +def apply(transform: t.Callable[..., t.Any]): + """Decorate a function with a transform function that is invoked on results returned from the decorated function. + + ```python + @apply(reversed) + def get_numbers(start): + "doc for get_numbers" + return range(start, start+3) + list(get_numbers(4)) + # [6, 5, 4] + ``` + ```python + get_numbers.__doc__ + # 'doc for get_numbers' + ``` + """ + + def wrap(func: t.Callable[P, t.Any]): + return functools.wraps(func)(compose(transform, func)) + + return wrap + + +@apply(bool) +@suppress(FileNotFoundError) +def _text_in_file(text: str, filename: Path): + return any(text in line for line in filename.open()) + + +def in_docker() -> bool: + """Is this current environment running in docker? + + ```python + type(in_docker()) + ``` + """ + return _dockerenv.exists() or _text_in_file("docker", _cgroup) + + +T = t.TypeVar("T") +K = t.TypeVar("K") + + def resolve_filepath(path: str) -> str: - """Resolve a file path to an absolute path, expand user and environment variables""" + """Resolve a file path to an absolute path, expand user and environment variables.""" try: return resolve_user_filepath(path, None) except FileNotFoundError: @@ -166,7 +311,9 @@ def validate_is_path(maybe_path: str) -> bool: def generate_context(framework_name: str) -> _ModelContext: - from .import_utils import is_torch_available, is_flax_available, is_tf_available + from .import_utils import is_flax_available + from .import_utils import is_tf_available + from .import_utils import is_torch_available framework_versions = {"transformers": pkg.get_pkg_version("transformers")} if is_torch_available(): @@ -174,7 +321,7 @@ def generate_context(framework_name: str) -> _ModelContext: if is_tf_available(): from bentoml._internal.frameworks.utils.tensorflow import get_tf_version - framework_versions["tensorflow-macos" if platform.system() == "Darwin" else "tensorflow"] = get_tf_version() + framework_versions["tensorflow"] = get_tf_version() if is_flax_available(): framework_versions.update( { @@ -186,6 +333,10 @@ def generate_context(framework_name: str) -> _ModelContext: return _ModelContext(framework_name=framework_name, framework_versions=framework_versions) +def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> DictStrAny: + return {"runtime": llm.runtime, "framework": "openllm"} + + _TOKENIZER_PREFIX = "_tokenizer_" @@ -198,6 +349,33 @@ def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[DictStrAny, return attrs, tokenizer_attrs +@_overload +def infer_auto_class(implementation: t.Literal["pt"]) -> type[openllm.AutoLLM]: + ... + + +@_overload +def infer_auto_class(implementation: t.Literal["tf"]) -> type[openllm.AutoTFLLM]: + ... + + +@_overload +def infer_auto_class(implementation: t.Literal["flax"]) -> type[openllm.AutoFlaxLLM]: + ... + + +def infer_auto_class(implementation: LiteralRuntime) -> type[BaseAutoLLMClass]: + if implementation == "tf": + from ..models.auto import AutoTFLLM as auto + elif implementation == "flax": + from ..models.auto import AutoFlaxLLM as auto + elif implementation == "pt": + from ..models.auto import AutoLLM as auto + else: + raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf')") + return auto + + # NOTE: The set marks contains a set of modules name # that are available above and are whitelisted # to be included in the extra_objects map. @@ -218,6 +396,7 @@ _import_structure = { "codegen": [], "dantic": [], "representation": ["ReprMixin"], + "lazy": ["LazyModule"], "import_utils": [ "OPTIONAL_DEPENDENCIES", "ENV_VARS_TRUE_VALUES", @@ -248,28 +427,18 @@ if t.TYPE_CHECKING: from . import LazyType as LazyType from . import analytics as analytics from . import bentoml_cattr as bentoml_cattr + from . import cached_contextmanager as cached_contextmanager from . import codegen as codegen from . import configure_logging as configure_logging + from . import configure_server_logging as configure_server_logging from . import copy_file_to_fs_folder as copy_file_to_fs_folder from . import dantic as dantic from . import first_not_none as first_not_none - from . import get_debug_mode as get_debug_mode - from . import get_quiet_mode as get_quiet_mode - from . import gpu_count as gpu_count - from . import lenient_issubclass as lenient_issubclass - from . import non_intrusive_setattr as non_intrusive_setattr - from . import pkg as pkg from . import reserve_free_port as reserve_free_port - from . import resolve_user_filepath as resolve_user_filepath from . import set_debug_mode as set_debug_mode from . import set_quiet_mode as set_quiet_mode - from . import in_notebook as in_notebook - from . import validate_or_create_dir as validate_or_create_dir from . import validate_is_path as validate_is_path - from . import resolve_filepath as resolve_filepath - from . import normalize_attrs_to_model_tokenizer_pair as normalize_attrs_to_model_tokenizer_pair - from . import generate_context as generate_context - from . import field_env_key as field_env_key + from . import validate_or_create_dir as validate_or_create_dir from .import_utils import ENV_VARS_TRUE_VALUES as ENV_VARS_TRUE_VALUES from .import_utils import OPTIONAL_DEPENDENCIES as OPTIONAL_DEPENDENCIES from .import_utils import DummyMetaclass as DummyMetaclass @@ -279,6 +448,9 @@ if t.TYPE_CHECKING: from .import_utils import is_datasets_available as is_datasets_available from .import_utils import is_einops_available as is_einops_available from .import_utils import is_flax_available as is_flax_available + from .import_utils import is_jupyter_available as is_jupyter_available + from .import_utils import is_jupytext_available as is_jupytext_available + from .import_utils import is_notebook_available as is_notebook_available from .import_utils import is_peft_available as is_peft_available from .import_utils import is_tf_available as is_tf_available from .import_utils import is_torch_available as is_torch_available @@ -287,10 +459,6 @@ if t.TYPE_CHECKING: from .import_utils import is_triton_available as is_triton_available from .import_utils import require_backends as require_backends from .import_utils import requires_dependencies as requires_dependencies - from .import_utils import is_jupyter_available as is_jupyter_available - from .import_utils import is_jupytext_available as is_jupytext_available - from .import_utils import is_notebook_available as is_notebook_available - from .lazy import LazyModule as LazyModule from .representation import ReprMixin as ReprMixin else: import sys diff --git a/src/openllm/utils/analytics.py b/src/openllm/utils/analytics.py index a7e311c6..01adc8d9 100644 --- a/src/openllm/utils/analytics.py +++ b/src/openllm/utils/analytics.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Telemetry related for OpenLLM tracking. +"""Telemetry related for OpenLLM tracking. Users can disable this with OPENLLM_DO_NOT_TRACK envvar. """ from __future__ import annotations - import contextlib import functools import os diff --git a/src/openllm/utils/codegen.py b/src/openllm/utils/codegen.py index 57ee131c..27fa6593 100644 --- a/src/openllm/utils/codegen.py +++ b/src/openllm/utils/codegen.py @@ -13,11 +13,14 @@ # limitations under the License. from __future__ import annotations - +import functools +import inspect +import linecache import logging -import os import string +import types import typing as t +from operator import itemgetter from pathlib import Path import orjson @@ -28,17 +31,16 @@ if t.TYPE_CHECKING: import openllm - DictStrAny = dict[str, t.Any] - ListStr = list[str] + from .._types import AnyCallable + from .._types import DictStrAny + from .._types import ListStr + from .._types import P - from attr import _make_method + PartialAny = functools.partial[t.Any] else: - # NOTE: Using internal API from attr here, since we are actually - # allowing subclass of openllm.LLMConfig to become 'attrs'-ish - from attr._make import _make_method - DictStrAny = dict ListStr = list + PartialAny = functools.partial _T = t.TypeVar("_T", bound=t.Callable[..., t.Any]) @@ -53,11 +55,12 @@ class ModelNameFormatter(string.Formatter): model_keyword: t.LiteralString = "__model_name__" def __init__(self, model_name: str): + """The formatter that extends model_name to be formatted the 'service.py'.""" super().__init__() self.model_name = model_name - def vformat(self, format_string: str) -> str: - return super().vformat(format_string, (), {self.model_keyword: self.model_name}) + def vformat(self, format_string: str, *args: t.Any, **attrs: t.Any) -> t.LiteralString: + return t.cast("t.LiteralString", super().vformat(format_string, (), {self.model_keyword: self.model_name})) def can_format(self, value: str) -> bool: try: @@ -117,9 +120,7 @@ _sentinel = object() def has_own_attribute(cls: type[t.Any], attrib_name: t.Any): - """ - Check whether *cls* defines *attrib_name* (and doesn't just inherit it). - """ + """Check whether *cls* defines *attrib_name* (and doesn't just inherit it).""" attr = getattr(cls, attrib_name, _sentinel) if attr is _sentinel: return False @@ -133,9 +134,7 @@ def has_own_attribute(cls: type[t.Any], attrib_name: t.Any): def get_annotations(cls: type[t.Any]) -> DictStrAny: - """ - Get annotations for *cls*. - """ + """Get annotations for *cls*.""" if has_own_attribute(cls, "__annotations__"): return cls.__annotations__ @@ -151,8 +150,7 @@ _classvar_prefixes = ( def is_class_var(annot: str | t.Any) -> bool: - """ - Check whether *annot* is a typing.ClassVar. + """Check whether *annot* is a typing.ClassVar. The string comparison hack is used to avoid evaluating all string annotations which would put attrs-based classes at a performance @@ -168,16 +166,14 @@ def is_class_var(annot: str | t.Any) -> bool: def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T: - """ - Add __module__ and __qualname__ to a *method* if possible. - """ + """Add __module__ and __qualname__ to a *method* if possible.""" try: method_or_cls.__module__ = cls.__module__ except AttributeError: pass try: - method_or_cls.__qualname__ = ".".join((cls.__qualname__, method_or_cls.__name__)) + method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}" except AttributeError: pass @@ -191,6 +187,64 @@ def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str return method_or_cls +def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = ""): + """Exec the script with the given global (globs) and local (locs) variables.""" + bytecode = compile(script, filename, "exec") + eval(bytecode, globs, locs) # noqa: S307 + + +# ported from attrs +def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable: + """Create the method with the script given and return the method object.""" + locs: DictStrAny = {} + + # In order of debuggers like PDB being able to step through the code, + # we add a fake linecache entry. + count = 1 + base_filename = filename + while True: + linecache_tuple = ( + len(script), + None, + script.splitlines(True), + filename, + ) + old_val = linecache.cache.setdefault(filename, linecache_tuple) + if old_val == linecache_tuple: + break + else: + filename = f"{base_filename[:-1]}-{count}>" + count += 1 + + _compile_and_eval(script, globs, locs, filename) + + return locs[name] + + +def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]): + """Create a tuple subclass to hold class attributes. + + The subclass is a bare tuple with properties for names. + + class MyClassAttributes(tuple): + __slots__ = () + x = property(itemgetter(0)) + """ + attr_class_name = f"{cls_name}Attributes" + attr_class_template = [ + f"class {attr_class_name}(tuple):", + " __slots__ = ()", + ] + if attr_names: + for i, attr_name in enumerate(attr_names): + attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))") + else: + attr_class_template.append(" pass") + globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property} + _compile_and_eval("\n".join(attr_class_template), globs) + return globs[attr_class_name] + + def generate_unique_filename(cls: type[t.Any], func_name: str): return f"<{cls.__name__} generated {func_name} {cls.__module__}." f"{getattr(cls, '__qualname__', cls.__name__)}>" @@ -203,7 +257,7 @@ def generate_function( globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None, ): - from . import DEBUG + from . import SHOW_CODEGEN script = "def %s(%s):\n %s\n" % ( func_name, @@ -214,7 +268,7 @@ def generate_function( if annotations: meth.__annotations__ = annotations - if DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3: + if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script) return meth @@ -227,7 +281,8 @@ def make_env_transformer( default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None, ): - from . import dantic, field_env_key + from . import dantic + from . import field_env_key def identity(_: str, x_value: t.Any) -> t.Any: return x_value @@ -268,3 +323,35 @@ def make_env_transformer( globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann}, ) + + +def gen_sdk(func: t.Callable[P, t.Any], name: str | None = None, **attrs: t.Any): + from .representation import ReprMixin + + if name is None: + name = func.__name__.strip("_") + + _signatures = inspect.signature(func).parameters + + def _repr(self: ReprMixin) -> str: + return f"" + + def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: + return ((k, _signatures[k].annotation) for k in self.__repr_keys__) + + return functools.update_wrapper( + types.new_class( + name, + (PartialAny, ReprMixin), + exec_body=lambda ns: ns.update( + { + "__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]), + "__repr_args__": _repr_args, + "__repr__": _repr, + "__doc__": inspect.cleandoc(t.cast(str, func.__doc__)), + "__module__": "openllm", + } + ), + )(func, **attrs), + func, + ) diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py index 15cbacd5..71f3d0b8 100644 --- a/src/openllm/utils/dantic.py +++ b/src/openllm/utils/dantic.py @@ -14,71 +14,40 @@ """A shim provides usable transition from pydantic to attrs.""" from __future__ import annotations - import functools import importlib import os +import sys import typing as t from enum import Enum import attr import click -import sys import click_option_group as cog import inflection import orjson -from click import ParamType, shell_completion as sc, types as click_types +from click import ParamType +from click import shell_completion as sc +from click import types as click_types import openllm -# NOTE: We need to do this so that overload can register -# correct overloads to typing registry -if hasattr(t, "get_overloads"): - from typing import overload -else: - from typing_extensions import overload - if t.TYPE_CHECKING: from attr import _ValidatorType - from .._types import ClickFunctionWrapper - from .._types import F - from .._types import O_co - from .._types import P + from .._types import ListAny _T = t.TypeVar("_T") -@overload def attrs_to_options( name: str, field: attr.Attribute[t.Any], model_name: str, typ: type[t.Any] | None = None, suffix_generation: bool = False, -) -> F[..., F[..., openllm.LLMConfig]]: - ... - - -@overload -def attrs_to_options( # type: ignore (overlapping overload) - name: str, - field: attr.Attribute[O_co], - model_name: str, - typ: type[t.Any] | None = None, - suffix_generation: bool = False, -) -> F[..., F[P, O_co]]: - ... - - -def attrs_to_options( - name: str, - field: attr.Attribute[t.Any], - model_name: str, - typ: type[t.Any] | None = None, - suffix_generation: bool = False, -) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]: +): # TODO: support parsing nested attrs class and Union envvar = field.metadata["env"] dasherized = inflection.dasherize(name) @@ -86,6 +55,8 @@ def attrs_to_options( if typ in (None, attr.NOTHING): typ = field.type + if typ is None: + raise RuntimeError(f"Failed to parse type for {name}") full_option_name = f"--{dasherized}" if field.type is bool: @@ -116,7 +87,7 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any: try: return orjson.loads(value.lower()) except orjson.JSONDecodeError as err: - raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") + raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None return value @@ -132,14 +103,16 @@ def Field( use_default_converter: bool = True, **attrs: t.Any, ): - """A decorator that extends attr.field with additional arguments, which provides the same - interface as pydantic's Field. + """A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field. By default, if both validator and ge are provided, then then ge will be piped into first, then all of the other validator will be run afterwards. Args: + default: The default value for ``dantic.Field``. Defaults to ``None``. ge: Greater than or equal to. Defaults to None. + le: Less than or equal to. Defaults to None. + validator: Optional attrs-compatible validators type. Default to None description: the documentation for the field. Defaults to None. env: the environment variable to read from. Defaults to None. auto_default: a bool indicating whether to use the default value as the environment. @@ -150,7 +123,7 @@ def Field( to True. If set to False, then the default converter will not be used. The default converter converts a given value from the environment variable for this given Field. - **kwargs: The rest of the arguments are passed to attr.field + **attrs: The rest of the arguments are passed to attr.field """ metadata = attrs.pop("metadata", {}) if description is None: @@ -205,13 +178,14 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]: """ from . import lenient_issubclass - assert t.get_origin(field_type) is not t.Union, "Unions are not supported" + if t.get_origin(field_type) is t.Union: + raise NotImplementedError("Unions are not supported") # enumeration strings or other Enum derivatives if lenient_issubclass(field_type, Enum): return EnumChoice(enum=field_type, case_sensitive=True) # literals are enum-like with way less functionality if is_literal(field_type): - return LiteralChoice(enum=field_type, case_sensitive=True) + return LiteralChoice(value=field_type, case_sensitive=True) # modules, classes, functions if is_typing(field_type): return ModuleType() @@ -248,6 +222,7 @@ def is_typing(field_type: type) -> bool: def is_literal(field_type: type) -> bool: """Checks whether the given field type is a Literal type or not. + Literals are weird: isinstance and subclass do not work, so you compare the origin with the Literal declaration itself. @@ -266,63 +241,76 @@ class ModuleType(ParamType): def _import_object(self, value: str) -> t.Any: module_name, class_name = value.rsplit(".", maxsplit=1) - assert all(s.isidentifier() for s in module_name.split(".")), f"'{value}' is not a valid module name" - assert class_name.isidentifier(), f"Variable '{class_name}' is not a valid identifier" + if not all(s.isidentifier() for s in module_name.split(".")): + raise ValueError(f"'{value}' is not a valid module name") + if not class_name.isidentifier(): + raise ValueError(f"Variable '{class_name}' is not a valid identifier") module = importlib.import_module(module_name) if class_name: try: return getattr(module, class_name) except AttributeError: - raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") + raise ImportError(f"Module '{module_name}' does not define a '{class_name}' variable.") from None return None - def convert(self, value: str, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: + def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any: try: if isinstance(value, str): return self._import_object(value) return value except Exception as exc: - self.fail(f"'{value}' is not a valid object ({type(exc)}: {str(exc)})", param, ctx) + self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx) class EnumChoice(click.Choice): name = "enum" def __init__(self, enum: Enum, case_sensitive: bool = False): + """Enum type support for click that extends ``click.Choice``. + + Args: + enum: Given enum + case_sensitive: Whether this choice should be case case_sensitive. + """ self.mapping = enum - self.internal_type = enum - super().__init__([e.name for e in self.mapping], case_sensitive) + self.internal_type = type(enum) + choices: ListAny = [e.name for e in enum.__class__] + super().__init__(choices, case_sensitive) def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> Enum: if isinstance(value, self.internal_type): return value result = super().convert(value, param, ctx) if isinstance(result, str): - result = self.mapping[result] + result = self.internal_type[result] return result class LiteralChoice(EnumChoice): name = "literal" - def __init__(self, enum: t.LiteralString, case_sensitive: bool = False): + def __init__(self, value: t.Any, case_sensitive: bool = False): + """Literal support for click.""" # expect every literal value to belong to the same primitive type - values = list(enum.__args__) + values = list(value.__args__) item_type = type(values[0]) - assert all(isinstance(v, item_type) for v in values), f"Field {enum} contains items of different types" + if not all(isinstance(v, item_type) for v in values): + raise ValueError(f"Field {value} contains items of different types.") self.internal_type = item_type self.mapping = {str(v): v for v in values} super(EnumChoice, self).__init__(list(self.mapping.keys()), case_sensitive) -def allows_multiple(field_type: t.Any) -> bool: +def allows_multiple(field_type: type) -> bool: """Checks whether the current type allows for multiple arguments to be provided as input or not. + For containers, it exploits click's support for lists and such to use the same option multiple times to create a complex object: `python run.py --subsets train --subsets test` # becomes `subsets: ["train", "test"]`. + Args: - field_type (type): pydantic type + field_type: pydantic type. Returns: bool: true if it's a composite field (lists, containers and so on), false otherwise @@ -360,8 +348,7 @@ def is_mapping(field_type: type) -> bool: def is_container(field_type: type) -> bool: - """Checks whether the current type is a container type ('contains' other types), like - lists and tuples. + """Checks whether the current type is a container type ('contains' other types), like lists and tuples. Args: field_type: pydantic field type @@ -391,12 +378,13 @@ def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType Returns: ParamType | tuple[ParamType]: single click-compatible type or a tuple """ - assert is_container(field_type), "Field type is not a container" + if not is_container(field_type): + raise ValueError("Field type is not a container type.") args = t.get_args(field_type) # Early out for untyped containers: standard lists, tuples, List[Any] # Use strings when the type is unknown, avoid click's type guessing if len(args) == 0: - return str + return click_types.convert_type(str) # Early out for homogenous containers: Tuple[int], List[str] if len(args) == 1: return parse_single_arg(args[0]) @@ -409,6 +397,7 @@ def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType def parse_single_arg(arg: type) -> ParamType: """Returns the click-compatible type for container origin types. + In this case, returns string when it's not inferrable, a JSON for mappings and the original type itself in every other case (ints, floats and so on). Bytes is a special case, not natively handled by click. @@ -421,13 +410,13 @@ def parse_single_arg(arg: type) -> ParamType: """ # When we don't know the type, we choose 'str' if arg is t.Any: - return str + return click_types.convert_type(str) # For containers and nested models, we use JSON if is_container(arg): return JsonType() if openllm.utils.lenient_issubclass(arg, bytes): return BytesType() - return arg + return click_types.convert_type(arg) class BytesType(ParamType): @@ -439,7 +428,7 @@ class BytesType(ParamType): try: return str.encode(value) except Exception as exc: - self.fail(f"'{value}' is not a valid string ({str(exc)})", param, ctx) + self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx) CYGWIN = sys.platform.startswith("cygwin") @@ -470,17 +459,14 @@ class CudaValueType(ParamType): return var def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]: - """Return a list of - :class:`~click.shell_completion.CompletionItem` objects for the - incomplete value. Most types do not provide completions, but - some do, and this allows custom types to provide custom - completions as well. + """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value. - :param ctx: Invocation context for this command. - :param param: The parameter that is requesting completion. - :param incomplete: Value being completed. May be empty. + Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well. - .. versionadded:: 8.0 + Args: + ctx: Invocation context for this command. + param: The parameter that is requesting completion. + incomplete: Value being completed. May be empty. """ from ..utils import gpu_count @@ -506,6 +492,7 @@ class CudaValueType(ParamType): return tuple(self.typ(x, param, ctx) for x in value.split(",")) def __repr__(self) -> str: + """CUDA is a click.STRING extension.""" return "STRING" @@ -516,6 +503,11 @@ class JsonType(ParamType): name = "json" def __init__(self, should_load: bool = True) -> None: + """Support JSON type for click.ParamType. + + Args: + should_load: Whether to load the JSON. Default to True. If False, the value won't be converted. + """ super().__init__() self.should_load = should_load @@ -525,4 +517,4 @@ class JsonType(ParamType): try: return orjson.loads(value) except orjson.JSONDecodeError as exc: - self.fail(f"'{value}' is not a valid JSON string ({str(exc)})", param, ctx) + self.fail(f"'{value}' is not a valid JSON string ({exc!s})", param, ctx) diff --git a/src/openllm/utils/dummy_flax_objects.py b/src/openllm/utils/dummy_flax_objects.py index 9e2700fa..4cc64279 100644 --- a/src/openllm/utils/dummy_flax_objects.py +++ b/src/openllm/utils/dummy_flax_objects.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t from ..utils import DummyMetaclass diff --git a/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py b/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py index 88ba77fd..c5ac9d16 100644 --- a/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py +++ b/src/openllm/utils/dummy_pt_and_cpm_kernels_objects.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t from ..utils import DummyMetaclass diff --git a/src/openllm/utils/dummy_pt_and_einops_objects.py b/src/openllm/utils/dummy_pt_and_einops_objects.py index 2a577906..f6dcca46 100644 --- a/src/openllm/utils/dummy_pt_and_einops_objects.py +++ b/src/openllm/utils/dummy_pt_and_einops_objects.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t from ..utils import DummyMetaclass diff --git a/src/openllm/utils/dummy_pt_and_triton_objects.py b/src/openllm/utils/dummy_pt_and_triton_objects.py index 643b7260..39e816dd 100644 --- a/src/openllm/utils/dummy_pt_and_triton_objects.py +++ b/src/openllm/utils/dummy_pt_and_triton_objects.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t from ..utils import DummyMetaclass diff --git a/src/openllm/utils/dummy_pt_objects.py b/src/openllm/utils/dummy_pt_objects.py index c4cbb88f..9e63a4da 100644 --- a/src/openllm/utils/dummy_pt_objects.py +++ b/src/openllm/utils/dummy_pt_objects.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t from ..utils import DummyMetaclass diff --git a/src/openllm/utils/dummy_tf_objects.py b/src/openllm/utils/dummy_tf_objects.py index 2d0a8651..bad558ed 100644 --- a/src/openllm/utils/dummy_tf_objects.py +++ b/src/openllm/utils/dummy_tf_objects.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t from ..utils import DummyMetaclass diff --git a/src/openllm/utils/import_utils.py b/src/openllm/utils/import_utils.py index e4c3eb22..b17b6409 100644 --- a/src/openllm/utils/import_utils.py +++ b/src/openllm/utils/import_utils.py @@ -12,17 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons. -""" +"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons.""" from __future__ import annotations - import functools import importlib import importlib.metadata import importlib.util import logging import os +import sys import typing as t from abc import ABCMeta from collections import OrderedDict @@ -38,15 +36,21 @@ from .representation import ReprMixin # NOTE: We need to do this so that overload can register # correct overloads to typing registry -if hasattr(t, "get_overloads"): +if sys.version_info[:2] >= (3, 11): from typing import overload else: from typing_extensions import overload if t.TYPE_CHECKING: BackendOrderredDict = OrderedDict[str, tuple[t.Callable[[], bool], str]] + from .._types import LiteralRuntime from .._types import P + + class _AnnotatedLazyLoader(LazyLoader): + DEFAULT_PROMPT_TEMPLATE: t.LiteralString | None | t.Callable[..., t.LiteralString] + else: + _AnnotatedLazyLoader = LazyLoader BackendOrderredDict = OrderedDict logger = logging.getLogger(__name__) @@ -188,7 +192,7 @@ def is_tf_available(): _tf_available = _tf_version is not None if _tf_available: if _tf_version and version.parse(_tf_version) < version.parse("2"): - logger.info(f"TensorFlow found but with version {_tf_version}. OpenLLM only supports TF 2.x") + logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version) _tf_available = False else: logger.info("Disabling Tensorflow because USE_TORCH is set") @@ -321,8 +325,9 @@ BACKENDS_MAPPING = BackendOrderredDict( class DummyMetaclass(ABCMeta): - """Metaclass for dummy object. It will raises ImportError - generated by ``require_backends`` if users try to access attributes from given class + """Metaclass for dummy object. + + It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class. """ _backends: t.List[str] @@ -368,7 +373,7 @@ class EnvVarMixin(ReprMixin): bettertransformer: str runtime: t.Literal["ggml", "transformers"] - framework_value: t.Literal["pt", "tf", "flax"] + framework_value: LiteralRuntime quantize_value: str | None bettertransformer_value: str | None runtime_value: t.Literal["ggml", "transformers"] @@ -385,17 +390,17 @@ class EnvVarMixin(ReprMixin): @overload def __getitem__(self, item: t.Literal["bettertransformer"]) -> str: ... @overload - def __getitem__(self, item: t.Literal['runtime']) -> str: ... + def __getitem__(self, item: t.Literal["runtime"]) -> str: ... @overload - def __getitem__(self, item: t.Literal['framework_value']) -> t.Literal['pt', 'tf', 'flax']: ... + def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime: ... @overload - def __getitem__(self, item: t.Literal['quantize_value']) -> str | None: ... + def __getitem__(self, item: t.Literal["quantize_value"]) -> str | None: ... @overload - def __getitem__(self, item: t.Literal['model_id_value']) -> str | None: ... + def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None: ... @overload - def __getitem__(self, item: t.Literal['bettertransformer_value']) -> str | None: ... + def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> str | None: ... @overload - def __getitem__(self, item: t.Literal['runtime_value']) -> t.Literal['ggml', 'transformers']: ... + def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]: ... # fmt: on def __getitem__(self, item: str | t.Any) -> t.Any: if hasattr(self, item): @@ -409,8 +414,8 @@ class EnvVarMixin(ReprMixin): quantize: t.LiteralString | None = None, runtime: t.Literal["ggml", "transformers"] = "transformers", ): - from .._configuration import field_env_key from . import codegen + from .._configuration import field_env_key model_name = inflection.underscore(model_name) @@ -464,5 +469,5 @@ class EnvVarMixin(ReprMixin): return getattr(self.module, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING") @property - def module(self) -> LazyLoader: - return LazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}") + def module(self): + return _AnnotatedLazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}") diff --git a/src/openllm/utils/lazy.py b/src/openllm/utils/lazy.py index 8082496f..75ec30b5 100644 --- a/src/openllm/utils/lazy.py +++ b/src/openllm/utils/lazy.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import importlib import importlib.machinery import itertools @@ -40,10 +39,9 @@ _reserved_namespace = {"__openllm_special__", "__openllm_migration__"} class LazyModule(types.ModuleType): - """ - Module class that surfaces all objects but only performs associated imports when the objects are requested. - This is a direct port from transformers.utils.import_utils._LazyModule for - backwards compatibility with transformers < 4.18 + """Module class that surfaces all objects but only performs associated imports when the objects are requested. + + This is a direct port from transformers.utils.import_utils._LazyModule for backwards compatibility with transformers < 4.18. This is an extension a more powerful LazyLoader. """ @@ -56,8 +54,22 @@ class LazyModule(types.ModuleType): module_file: str, import_structure: dict[str, list[str]], module_spec: importlib.machinery.ModuleSpec | None = None, + doc: str | None = None, extra_objects: dict[str, t.Any] | None = None, ): + """Lazily load this module as an object. + + It does instantiate a __all__ and __dir__ for IDE support + + Args: + name: module name + module_file: the given file. Often default to 'globals()['__file__']' + import_structure: A dictionary of module and its corresponding attributes that can be loaded from given 'module' + module_spec: __spec__ of the lazily loaded module + doc: Optional docstring for this module. + extra_objects: Any additional objects that this module can also be accessed. Useful for additional metadata as well + as any locals() functions + """ super().__init__(name) self._modules = set(import_structure.keys()) self._class_to_module: dict[str, str] = {} @@ -70,24 +82,22 @@ class LazyModule(types.ModuleType): self.__file__ = module_file self.__spec__ = module_spec self.__path__ = [os.path.dirname(module_file)] + self.__doc__ = doc self._objects = _extra_objects self._name = name self._import_structure = import_structure - # Needed for autocompletion in an IDE def __dir__(self): + """Needed for autocompletion in an IDE.""" result = t.cast("list[str]", super().__dir__()) # The elements of self.__all__ that are submodules may or # may not be in the dir already, depending on whether # they have been accessed or not. So we only add the # elements of self.__all__ that are not already in the dir. - for attribute in self.__all__: - if attribute not in result: - result.append(attribute) - return result + return result + [i for i in self.__all__ if i not in result] def __getitem__(self, key: str) -> t.Any: - # currently, this is reserved to only internal uses and users shouldn't use this. + """This is reserved to only internal uses and users shouldn't use this.""" if self._objects.get("__openllm_special__") is None: raise UsageNotAllowedError(f"'{self._name}' is not allowed to be used as a dict.") _special_mapping = self._objects.get("__openllm_special__", {}) @@ -101,6 +111,10 @@ class LazyModule(types.ModuleType): raise KeyError(f"Failed to lookup '{key}' in '{self._name}'") from e def __getattr__(self, name: str) -> t.Any: + """Equivocal __getattr__ implementation. + + It checks from _objects > _modules and does it recursively. + """ if name in _reserved_namespace: raise ForbiddenAttributeError( f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified." @@ -111,6 +125,7 @@ class LazyModule(types.ModuleType): warnings.warn( f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, + stacklevel=3, ) return getattr(self, cur_value) if name in self._objects: @@ -136,4 +151,5 @@ class LazyModule(types.ModuleType): ) from e def __reduce__(self): + """This is to ensure any given module is pickle-able.""" return (self.__class__, (self._name, self.__file__, self._import_structure)) diff --git a/src/openllm/utils/representation.py b/src/openllm/utils/representation.py index bf267c0a..0294e783 100644 --- a/src/openllm/utils/representation.py +++ b/src/openllm/utils/representation.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t from abc import abstractmethod @@ -27,6 +26,7 @@ if t.TYPE_CHECKING: class ReprMixin: """This class display possible representation of given class. + It can be used for implementing __rich_pretty__ and __pretty__ methods in the future. Most subclass needs to implement a __repr_keys__ property. @@ -41,12 +41,20 @@ class ReprMixin: """This can be overriden by base class using this mixin.""" def __repr__(self) -> str: + """The `__repr__` for any subclass of Mixin. + + It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict. + """ from . import bentoml_cattr serialized = {k: bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()} return f"{self.__class__.__name__} {orjson.dumps(serialized, option=orjson.OPT_INDENT_2).decode()}" def __str__(self) -> str: + """The string representation of the given Mixin subclass. + + It will contains all of the attributes from __repr_keys__ + """ return self.__repr_str__(" ") def __repr_name__(self) -> str: @@ -54,7 +62,12 @@ class ReprMixin: return self.__class__.__name__ def __repr_str__(self, join_str: str) -> str: - return join_str.join(repr(v) if a is None else f"{a}={repr(v)}" for a, v in self.__repr_args__()) + """To be used with __str__.""" + return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__()) def __repr_args__(self) -> ReprArgs: + """This can also be overriden by base class using this mixin. + + By default it does a getattr of the current object from __repr_keys__. + """ return ((k, getattr(self, k)) for k in self.__repr_keys__) diff --git a/src/openllm_client/__init__.py b/src/openllm_client/__init__.py index 5daee98a..3f1e7f20 100644 --- a/src/openllm_client/__init__.py +++ b/src/openllm_client/__init__.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -The actual client implementation. Use ``openllm.client`` instead. +"""The actual client implementation. + +Use ``openllm.client`` instead. This holds the implementation of the client, which is used to communicate with the OpenLLM server. It is used to send requests to the server, and receive responses. """ diff --git a/src/openllm_client/_prompt.py b/src/openllm_client/_prompt.py index 82b69450..42659d83 100644 --- a/src/openllm_client/_prompt.py +++ b/src/openllm_client/_prompt.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations - import typing as t import attr @@ -42,7 +41,7 @@ class PromptTemplate: input_variables: t.Sequence[str] def to_str(self, __partial_dict__: PartialDict | None = None, **attrs: str) -> str: - """Generate a prompt from the template and input variables""" + """Generate a prompt from the template and input variables.""" if __partial_dict__: return _default_formatter.vformat(self.template, (), __partial_dict__) if not attrs: @@ -58,7 +57,7 @@ class PromptTemplate: @classmethod def from_default(cls, model: str, /, **prompt_attrs: t.Any) -> PromptTemplate: - template = getattr(openllm.utils.EnvVarMixin(model).module, "DEFAULT_PROMPT_TEMPLATE") + template = openllm.utils.EnvVarMixin(model).module.DEFAULT_PROMPT_TEMPLATE if template is None: raise ValueError(f"Model {model} does not have a default prompt template.") if callable(template): diff --git a/src/openllm_client/runtimes/__init__.py b/src/openllm_client/runtimes/__init__.py index c4b22214..29874ac8 100644 --- a/src/openllm_client/runtimes/__init__.py +++ b/src/openllm_client/runtimes/__init__.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Client that supports REST/gRPC protocol to interact with a LLMServer.""" +from .grpc import AsyncGrpcClient as AsyncGrpcClient from .grpc import GrpcClient as GrpcClient +from .http import AsyncHTTPClient as AsyncHTTPClient from .http import HTTPClient as HTTPClient diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py index dc8d00e0..6e3d2cb1 100644 --- a/src/openllm_client/runtimes/base.py +++ b/src/openllm_client/runtimes/base.py @@ -13,29 +13,30 @@ # limitations under the License. from __future__ import annotations - import asyncio +import logging +import sys import typing as t from abc import abstractmethod +from http import HTTPStatus from urllib.parse import urljoin import httpx import bentoml import openllm -import logging # NOTE: We need to do this so that overload can register # correct overloads to typing registry -if hasattr(t, "get_overloads"): +if sys.version_info[:2] >= (3, 11): from typing import overload else: from typing_extensions import overload if t.TYPE_CHECKING: import transformers - from openllm.models.auto.factory import _BaseAutoLLMClass + from openllm._types import LiteralRuntime class AnnotatedClient(bentoml.client.Client): def health(self, *args: t.Any, **attrs: t.Any) -> t.Any: @@ -44,12 +45,6 @@ if t.TYPE_CHECKING: async def async_health(self) -> t.Any: ... - def call(self, name: str, inputs: t.Any, **attrs: t.Any) -> t.Any: - ... - - async def acall(self, name: str, inputs: t.Any, **attrs: t.Any) -> t.Any: - ... - def generate_v1(self, qa: openllm.GenerationInput) -> dict[str, t.Any]: ... @@ -70,7 +65,10 @@ def in_async_context() -> bool: return False -class ClientMixin: +T = t.TypeVar("T") + + +class ClientMeta(t.Generic[T]): _api_version: str _client_class: type[bentoml.client.Client] @@ -84,9 +82,9 @@ class ClientMixin: def __init__(self, address: str, timeout: int = 30): self._address = address self._timeout = timeout - assert self._host and self._port, "Make sure to setup _host and _port based on your client implementation." def __init_subclass__(cls, *, client_type: t.Literal["http", "grpc"] = "http", api_version: str = "v1"): + """Initialise subclass for HTTP and gRPC client type.""" cls._client_class = bentoml.client.HTTPClient if client_type == "http" else bentoml.client.GrpcClient cls._api_version = api_version @@ -102,7 +100,7 @@ class ClientMixin: return self.__agent__ @property - def _metadata(self) -> dict[str, t.Any]: + def _metadata(self) -> T: if in_async_context(): return httpx.post(urljoin(self._address, f"/{self._api_version}/metadata")).json() return self.call("metadata") @@ -114,7 +112,7 @@ class ClientMixin: @property @abstractmethod - def framework(self) -> t.Literal["pt", "flax", "tf"]: + def framework(self) -> LiteralRuntime: raise NotImplementedError @property @@ -135,10 +133,7 @@ class ClientMixin: @property def llm(self) -> openllm.LLM[t.Any, t.Any]: if self.__llm__ is None: - self.__llm__ = t.cast( - "_BaseAutoLLMClass", - openllm[self.framework], # type: ignore (internal API) - ).for_model(self.model_name) + self.__llm__ = openllm.infer_auto_class(self.framework).for_model(self.model_name) return self.__llm__ @property @@ -171,7 +166,7 @@ class ClientMixin: ... -class BaseClient(ClientMixin): +class BaseClient(ClientMeta[T]): def health(self) -> t.Any: raise NotImplementedError @@ -183,22 +178,32 @@ class BaseClient(ClientMixin): def query(self, prompt: str, *, return_raw_response: t.Literal[True] = ..., **attrs: t.Any) -> dict[str, t.Any]: ... - def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str: - return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare(prompt, **attrs) + @overload + def query(self, prompt: str, *, return_attrs: t.Literal[True] = True, **attrs: t.Any) -> openllm.GenerationOutput: + ... + + def query(self, prompt: str, **attrs: t.Any) -> openllm.GenerationOutput | dict[str, t.Any] | str: + # NOTE: We set use_default_prompt_template to False for now. + use_default_prompt_template = attrs.pop("use_default_prompt_template", False) + return_attrs = attrs.pop("return_attrs", False) + return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare( + prompt, use_default_prompt_template=use_default_prompt_template, **attrs + ) inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) if in_async_context(): result = httpx.post( urljoin(self._address, f"/{self._api_version}/generate"), - json=openllm.utils.bentoml_cattr.unstructure(inputs), + json=inputs.model_dump(), timeout=self.timeout, ).json() else: - result = self.call("generate", inputs) + result = self.call("generate", inputs.model_dump()) r = self.postprocess(result) + if return_attrs: + return r if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r) - return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) def ask_agent( @@ -235,10 +240,21 @@ class BaseClient(ClientMixin): raise NotImplementedError -class BaseAsyncClient(ClientMixin): +class BaseAsyncClient(ClientMeta[T]): async def health(self) -> t.Any: raise NotImplementedError + @overload + async def query( + self, + prompt: str, + *, + return_attrs: t.Literal[True] = True, + return_raw_response: bool | None = ..., + **attrs: t.Any, + ) -> openllm.GenerationOutput: + ... + @overload async def query(self, prompt: str, *, return_raw_response: t.Literal[False] = ..., **attrs: t.Any) -> str: ... @@ -249,19 +265,21 @@ class BaseAsyncClient(ClientMixin): ) -> dict[str, t.Any]: ... - async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str: + async def query(self, prompt: str, **attrs: t.Any) -> dict[str, t.Any] | str | openllm.GenerationOutput: # NOTE: We set use_default_prompt_template to False for now. use_default_prompt_template = attrs.pop("use_default_prompt_template", False) + return_attrs = attrs.pop("return_attrs", False) return_raw_response, prompt, generate_kwargs, postprocess_kwargs = self.prepare( prompt, use_default_prompt_template=use_default_prompt_template, **attrs ) inputs = openllm.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)) - res = await self.acall("generate", inputs) + res = await self.acall("generate", inputs.model_dump()) r = self.postprocess(res) + if return_attrs: + return r if return_raw_response: return openllm.utils.bentoml_cattr.unstructure(r) - return self.llm.postprocess_generate(prompt, r.responses, **postprocess_kwargs) async def ask_agent( @@ -273,13 +291,17 @@ class BaseAsyncClient(ClientMixin): agent_type: t.LiteralString = "hf", **attrs: t.Any, ) -> t.Any: - """Async version of agent.run""" + """Async version of agent.run.""" if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs) else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'") async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + if not openllm.utils.is_transformers_supports_agent(): + raise RuntimeError( + "This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0" + ) if len(args) > 1: raise ValueError("'args' should only take one positional argument.") task = kwargs.pop("task", args[0]) @@ -293,7 +315,7 @@ class BaseAsyncClient(ClientMixin): _hf_agent = self._hf_agent - prompt = _hf_agent.format_prompt(task) + prompt = t.cast(str, _hf_agent.format_prompt(task)) stop = ["Task:"] async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client: response = await client.post( @@ -303,7 +325,7 @@ class BaseAsyncClient(ClientMixin): "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop}, }, ) - if response.status_code != 200: + if response.status_code != HTTPStatus.OK: raise ValueError(f"Error {response.status_code}: {response.json()}") result = response.json()[0]["generated_text"] diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py index b32c91c0..e5b4e2d9 100644 --- a/src/openllm_client/runtimes/grpc.py +++ b/src/openllm_client/runtimes/grpc.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import asyncio import logging import typing as t @@ -27,46 +26,51 @@ from .base import BaseClient if t.TYPE_CHECKING: - import grpc_health.v1.health_pb2 as health_pb2 + from grpc_health.v1 import health_pb2 from bentoml.grpc.v1.service_pb2 import Response + from openllm._types import LiteralRuntime logger = logging.getLogger(__name__) class GrpcClientMixin: - _metadata: Response + if t.TYPE_CHECKING: + + @property + def _metadata(self) -> Response: + ... @property def model_name(self) -> str: try: return self._metadata.json.struct_value.fields["model_name"].string_value except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property - def framework(self) -> t.Literal["pt", "flax", "tf"]: + def framework(self) -> LiteralRuntime: try: value = self._metadata.json.struct_value.fields["framework"].string_value if value not in ("pt", "flax", "tf"): raise KeyError return value except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def timeout(self) -> int: try: return int(self._metadata.json.struct_value.fields["timeout"].number_value) except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def model_id(self) -> str: try: return self._metadata.json.struct_value.fields["model_id"].string_value except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def configuration(self) -> dict[str, t.Any]: @@ -74,7 +78,7 @@ class GrpcClientMixin: v = self._metadata.json.struct_value.fields["configuration"].string_value return orjson.loads(v) except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None def postprocess(self, result: Response | dict[str, t.Any]) -> openllm.GenerationOutput: if isinstance(result, dict): @@ -85,7 +89,7 @@ class GrpcClientMixin: return openllm.GenerationOutput(**MessageToDict(result.json, preserving_proto_field_name=True)) -class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"): +class GrpcClient(GrpcClientMixin, BaseClient["Response"], client_type="grpc"): def __init__(self, address: str, timeout: int = 30): self._host, self._port = address.split(":") super().__init__(address, timeout) @@ -94,7 +98,7 @@ class GrpcClient(GrpcClientMixin, BaseClient, client_type="grpc"): return asyncio.run(self._cached.health("bentoml.grpc.v1.BentoService")) -class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient, client_type="grpc"): +class AsyncGrpcClient(GrpcClientMixin, BaseAsyncClient["Response"], client_type="grpc"): def __init__(self, address: str, timeout: int = 30): self._host, self._port = address.split(":") super().__init__(address, timeout) diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py index 05349379..7859d8ca 100644 --- a/src/openllm_client/runtimes/http.py +++ b/src/openllm_client/runtimes/http.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import logging import typing as t from urllib.parse import urlparse @@ -26,52 +25,63 @@ from .base import BaseAsyncClient from .base import BaseClient +if t.TYPE_CHECKING: + from openllm._types import DictStrAny + from openllm._types import LiteralRuntime +else: + DictStrAny = dict + + logger = logging.getLogger(__name__) class HTTPClientMixin: - _metadata: dict[str, t.Any] + if t.TYPE_CHECKING: + + @property + def _metadata(self) -> DictStrAny: + ... @property def model_name(self) -> str: try: return self._metadata["model_name"] except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def model_id(self) -> str: try: return self._metadata["model_name"] except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property - def framework(self) -> t.Literal["pt", "flax", "tf"]: + def framework(self) -> LiteralRuntime: try: return self._metadata["framework"] except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def timeout(self) -> int: try: return self._metadata["timeout"] except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None @property def configuration(self) -> dict[str, t.Any]: try: return orjson.loads(self._metadata["configuration"]) except KeyError: - raise RuntimeError("Malformed service endpoint. (Possible malicious)") + raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result) -class HTTPClient(HTTPClientMixin, BaseClient): +class HTTPClient(HTTPClientMixin, BaseClient[DictStrAny]): def __init__(self, address: str, timeout: int = 30): address = address if "://" in address else "http://" + address self._host, self._port = urlparse(address).netloc.split(":") @@ -81,7 +91,7 @@ class HTTPClient(HTTPClientMixin, BaseClient): return self._cached.health() -class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient): +class AsyncHTTPClient(HTTPClientMixin, BaseAsyncClient[DictStrAny]): def __init__(self, address: str, timeout: int = 30): address = address if "://" in address else "http://" + address self._host, self._port = urlparse(address).netloc.split(":") diff --git a/tests/_strategies/_configuration.py b/tests/_strategies/_configuration.py index 2615bdbe..823d8cb7 100644 --- a/tests/_strategies/_configuration.py +++ b/tests/_strategies/_configuration.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import logging import typing as t @@ -62,11 +61,11 @@ def make_llm_config( lines.append(f' __config__ = {{ {", ".join(_config_args)} }}') if fields is not None: for field, type_, default in fields: - lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({repr(default)})") + lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})") if generation_fields is not None: generation_lines = ["class GenerationConfig:"] for field, default in generation_fields: - generation_lines.append(f" {field} = {repr(default)}") + generation_lines.append(f" {field} = {default!r}") lines.extend((" " + line for line in generation_lines)) script = "\n".join(lines) diff --git a/tests/test_client.py b/tests/client_test.py similarity index 100% rename from tests/test_client.py rename to tests/client_test.py diff --git a/tests/test_configuration.py b/tests/configuration_test.py similarity index 88% rename from tests/test_configuration.py rename to tests/configuration_test.py index a0302371..2671ea3a 100644 --- a/tests/test_configuration.py +++ b/tests/configuration_test.py @@ -13,9 +13,9 @@ # limitations under the License. """All configuration-related tests for openllm.LLMConfig. This will include testing -for ModelEnv construction and parsing environment variables.""" +for ModelEnv construction and parsing environment variables. +""" from __future__ import annotations - import contextlib import logging import os @@ -125,29 +125,20 @@ def test_complex_struct_dump( generation_fields=(("temperature", temperature),), ) sent = cl_() - assert ( - sent.model_dump()["field1"] == field1 and sent.model_dump()["generation_config"]["temperature"] == temperature - ) - assert ( - sent.model_dump(flatten=True)["field1"] == field1 - and sent.model_dump(flatten=True)["temperature"] == temperature - ) + assert sent.model_dump()["field1"] == field1 + assert sent.model_dump()["generation_config"]["temperature"] == temperature + assert sent.model_dump(flatten=True)["field1"] == field1 + assert sent.model_dump(flatten=True)["temperature"] == temperature passed = cl_(field1=input_field1, temperature=input_temperature) - assert ( - passed.model_dump()["field1"] == input_field1 - and passed.model_dump()["generation_config"]["temperature"] == input_temperature - ) - assert ( - passed.model_dump(flatten=True)["field1"] == input_field1 - and passed.model_dump(flatten=True)["temperature"] == input_temperature - ) + assert passed.model_dump()["field1"] == input_field1 + assert passed.model_dump()["generation_config"]["temperature"] == input_temperature + assert passed.model_dump(flatten=True)["field1"] == input_field1 + assert passed.model_dump(flatten=True)["temperature"] == input_temperature pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1) - assert ( - pas_nested.model_dump()["field1"] == input_field1 - and pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature - ) + assert pas_nested.model_dump()["field1"] == input_field1 + assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature @contextlib.contextmanager @@ -207,7 +198,7 @@ def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPat @given(model_settings()) -@pytest.mark.parametrize("return_dict,typ", [(True, DictStrAny), (False, transformers.GenerationConfig)]) +@pytest.mark.parametrize(("return_dict", "typ"), [(True, DictStrAny), (False, transformers.GenerationConfig)]) def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings): cl_ = make_llm_config("ConversionLLM", gen_settings) assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ) diff --git a/tests/conftest.py b/tests/conftest.py index 9d723ae7..7fad096e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,9 +13,55 @@ # limitations under the License. from __future__ import annotations +import itertools +import typing as t import pytest +import openllm + + +if t.TYPE_CHECKING: + from openllm._types import LiteralRuntime + + +_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m"} +_PROMPT_MAPPING = { + "qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?", +} + + +def parametrise_local_llm( + model: str, +) -> t.Generator[tuple[str, openllm.LLMRunner | openllm.LLM[t.Any, t.Any]], None, None]: + if model not in _FRAMEWORK_MAPPING: + pytest.skip(f"'{model}' is not yet supported in framework testing.") + + runtime_impl: tuple[LiteralRuntime, ...] = tuple() + if model in openllm.MODEL_MAPPING_NAMES: + runtime_impl += ("pt",) + if model in openllm.MODEL_FLAX_MAPPING_NAMES: + runtime_impl += ("flax",) + if model in openllm.MODEL_TF_MAPPING_NAMES: + runtime_impl += ("tf",) + + for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()): + llm = openllm.Runner( + model, + model_id=_FRAMEWORK_MAPPING[model], + ensure_available=True, + implementation=framework, + init_local=True, + ) + yield prompt, llm + + +def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: + if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames: + metafunc.parametrize( + "prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])] + ) + def pytest_sessionfinish(session: pytest.Session, exitstatus: int): # If no tests are collected, pytest exists with code 5, which makes the CI fail. diff --git a/tests/test_llm.py b/tests/llm_test.py similarity index 60% rename from tests/test_llm.py rename to tests/llm_test.py index 6a5f9882..37ad3d78 100644 --- a/tests/test_llm.py +++ b/tests/llm_test.py @@ -13,12 +13,16 @@ # limitations under the License. from __future__ import annotations - import typing as t + import openllm -import pytest from openllm._llm import make_tag + +if t.TYPE_CHECKING: + import pytest + + HF_INTERNAL_T5_TESTING = "hf-internal-testing/tiny-random-t5" @@ -31,21 +35,6 @@ def patch_hash_from_file(_: str, algorithm: t.LiteralString = "sha1") -> str: return "d88a1a40e354a0c7fa6f9055938594e6a4c712e0" -def test_tag_generation_from_custom_path( - tmp_path_factory: pytest.TempPathFactory, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture -): - monkeypatch.setattr(openllm._llm, "generate_hash_from_file", patch_hash_from_file) - local_path = tmp_path_factory.mktemp("local_t5") - llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True) - llm.save_pretrained(local_path) - - with caplog.at_level("WARNING"): - tag = make_tag(local_path.resolve().__fspath__()) - - assert tag.version == "d88a1a40e354a0c7fa6f9055938594e6a4c712e0" - assert "Given 'model_id" in caplog.text - - def test_tag_generation_quiet_log(tmp_path_factory: pytest.TempPathFactory, caplog: pytest.LogCaptureFixture): local_path = tmp_path_factory.mktemp("local_t5") llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True) @@ -54,12 +43,3 @@ def test_tag_generation_quiet_log(tmp_path_factory: pytest.TempPathFactory, capl with caplog.at_level("WARNING"): make_tag(local_path.resolve().__fspath__(), quiet=True) assert not caplog.text - - -def test_tag_generation_debug_log(caplog: pytest.LogCaptureFixture): - with caplog.at_level("DEBUG"): - make_tag(HF_INTERNAL_T5_TESTING) - assert ( - "The full tag to be saved under model store: 'pt-hf-internal-testing-tiny-random-t5:2f582cd79ed5795b71539951d237945bc1c5ac7e'" - in caplog.text - ) diff --git a/tests/models/__snapshots__/opt_test/test_opt_125m[container].json b/tests/models/__snapshots__/opt_test/test_opt_125m[container].json new file mode 100644 index 00000000..0727c509 --- /dev/null +++ b/tests/models/__snapshots__/opt_test/test_opt_125m[container].json @@ -0,0 +1,34 @@ +{ + "configuration": { + "format_outputs": false, + "generation_config": { + "diversity_penalty": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "encoder_repetition_penalty": 1.0, + "epsilon_cutoff": 0.0, + "eta_cutoff": 0.0, + "length_penalty": 1.0, + "max_new_tokens": 20, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "remove_invalid_values": false, + "renormalize_logits": false, + "repetition_penalty": 1.0, + "temperature": 0.75, + "top_k": 15, + "top_p": 1.0, + "typical_p": 1.0, + "use_cache": true + } + }, + "responses": [ + "What is Deep learning?\nDeep learning is a new way of studying the content and making an informed decision. It is the" + ] +} \ No newline at end of file diff --git a/tests/models/conftest.py b/tests/models/conftest.py index ac1f4060..60b70170 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -13,45 +13,303 @@ # limitations under the License. from __future__ import annotations - -import types +import asyncio +import contextlib +import functools +import logging +import sys +import time import typing as t +from abc import ABC +from abc import abstractmethod +import attr +import docker +import docker.errors +import docker.types +import orjson import pytest +from syrupy.extensions.json import JSONSnapshotExtension import openllm +from openllm._llm import normalise_model_name +logger = logging.getLogger(__name__) + if t.TYPE_CHECKING: - from openllm.models.auto.factory import _BaseAutoLLMClass + import subprocess -_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m"} -_PROMPT_MAPPING = { - "qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?", - "default": "What is the weather in SF?", -} + from openllm_client.runtimes.base import BaseAsyncClient + from syrupy.assertion import SnapshotAssertion + from syrupy.types import PropertyFilter + from syrupy.types import PropertyMatcher + from syrupy.types import SerializableData + from syrupy.types import SerializedData + + from openllm._configuration import GenerationConfig + from openllm._types import DictStrAny + from openllm._types import ListAny + +else: + DictStrAny = dict + ListAny = list -def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: - models, fname = t.cast(types.ModuleType, metafunc.module).__name__.partition(".")[-1].split(".")[1:] +class ResponseComparator(JSONSnapshotExtension): + def serialize( + self, + data: SerializableData, + *, + exclude: PropertyFilter | None = None, + matcher: PropertyMatcher | None = None, + ) -> SerializedData: + if openllm.utils.LazyType(ListAny).isinstance(data): + data = [d.unmarshaled for d in data] + else: + data = data.unmarshaled + data = self._filter(data=data, depth=0, path=(), exclude=exclude, matcher=matcher) + return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode() - if "tf" in fname: - framework = "tf" - elif "flax" in fname: - framework = "flax" + def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool: + def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]: + try: + data = orjson.loads(data) + except orjson.JSONDecodeError as err: + raise ValueError(f"Failed to decode JSON data: {data}") from err + if openllm.utils.LazyType(DictStrAny).isinstance(data): + return openllm.GenerationOutput(**data) + elif openllm.utils.LazyType(ListAny).isinstance(data): + return [openllm.GenerationOutput(**d) for d in data] + else: + raise NotImplementedError(f"Data {data} has unsupported type.") + + serialized_data = convert_data(serialized_data) + snapshot_data = convert_data(snapshot_data) + + if openllm.utils.LazyType(ListAny).isinstance(serialized_data): + serialized_data = [serialized_data] + if openllm.utils.LazyType(ListAny).isinstance(snapshot_data): + snapshot_data = [snapshot_data] + + def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool: + return s == t + + def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool: + return ( + len(s.responses) == len(t.responses) + and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) + and eq_config(s.marshaled_config, t.marshaled_config) + ) + + return len(serialized_data) == len(snapshot_data) and all( + [eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)] + ) + + +@pytest.fixture() +def response_snapshot(snapshot: SnapshotAssertion): + return snapshot.use_extension(ResponseComparator) + + +@attr.define(init=False) +class _Handle(ABC): + port: int + deployment_mode: t.Literal["container", "local"] + + client: BaseAsyncClient[t.Any] = attr.field(init=False) + + if t.TYPE_CHECKING: + + def __attrs_init__(self, *args: t.Any, **attrs: t.Any): + ... + + def __attrs_post_init__(self): + self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}") + + @abstractmethod + def status(self) -> bool: + raise NotImplementedError + + async def health(self, timeout: int = 240): + start_time = time.time() + while time.time() - start_time < timeout: + if not self.status(): + raise RuntimeError(f"Failed to initialise {self.__class__.__name__}") + await self.client.health() + try: + await self.client.query("sanity") + return + except Exception: + time.sleep(1) + raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.") + + +@attr.define(init=False) +class LocalHandle(_Handle): + process: subprocess.Popen[bytes] + + def __init__( + self, + process: subprocess.Popen[bytes], + port: int, + deployment_mode: t.Literal["container", "local"], + ): + self.__attrs_init__(port, deployment_mode, process) + + def status(self) -> bool: + return self.process.poll() is None + + +class HandleProtocol(t.Protocol): + @contextlib.contextmanager + def __call__( + *, + model: str, + model_id: str, + image_tag: str, + quantize: t.AnyStr | None = None, + ) -> t.Generator[_Handle, None, None]: + ... + + +@attr.define(init=False) +class DockerHandle(_Handle): + container_name: str + docker_client: docker.DockerClient + + def __init__( + self, + docker_client: docker.DockerClient, + container_name: str, + port: int, + deployment_mode: t.Literal["container", "local"], + ): + self.__attrs_init__(port, deployment_mode, container_name, docker_client) + + def status(self) -> bool: + container = self.docker_client.containers.get(self.container_name) + return container.status in ["running", "created"] + + +@contextlib.contextmanager +def _local_handle( + model: str, + model_id: str, + image_tag: str, + deployment_mode: t.Literal["container", "local"], + quantize: t.Literal["int8", "int4", "gptq"] | None = None, + *, + _serve_grpc: bool = False, +): + with openllm.utils.reserve_free_port() as port: + pass + + if not _serve_grpc: + proc = openllm.start( + model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True + ) else: - framework = "pt" + proc = openllm.start_grpc( + model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True + ) - llm, runner_kwargs = t.cast( - "_BaseAutoLLMClass", - openllm[framework], # type: ignore - ).for_model(models, model_id=_FRAMEWORK_MAPPING[models], return_runner_kwargs=True, ensure_available=True) - llm.ensure_model_id_exists() - if "runner" in metafunc.function.__name__: - llm = llm.to_runner(**runner_kwargs) - llm.init_local(quiet=True) + yield LocalHandle(proc, port, deployment_mode) + proc.terminate() + proc.wait(60) - if "qa" in metafunc.fixturenames: - metafunc.parametrize("prompt,llm,qa", [(_PROMPT_MAPPING["qa"], llm, True)]) + process_output = proc.stdout.read() + print(process_output, file=sys.stderr) + + proc.stdout.close() + if proc.stderr: + proc.stderr.close() + + +@contextlib.contextmanager +def _container_handle( + model: str, + model_id: str, + image_tag: str, + deployment_mode: t.Literal["container", "local"], + quantize: t.Literal["int8", "int4", "gptq"] | None = None, + *, + _serve_grpc: bool = False, +): + envvar = openllm.utils.EnvVarMixin(model) + + with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port: + pass + container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_") + client = docker.from_env() + try: + container = client.containers.get(container_name) + container.stop() + container.wait() + container.remove() + except docker.errors.NotFound: + pass + + args = ["serve" if not _serve_grpc else "serve-grpc"] + + env: DictStrAny = {} + + if quantize is not None: + env[envvar.quantize] = quantize + + available = openllm.utils.gpu_count() + gpus = len(available) if len(available) > 0 else -1 + devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None + + container = client.containers.run( + image_tag, + command=args, + name=container_name, + environment=env, + auto_remove=False, + detach=True, + device_requests=devs, + ports={"3000/tcp": port, "3001/tcp": prom_port}, + ) + + yield DockerHandle(client, container.name, port, deployment_mode) + + try: + container.stop() + container.wait() + except docker.errors.NotFound: + pass + + container_output = container.logs().decode("utf-8") + print(container_output, file=sys.stderr) + + container.remove() + + +@pytest.fixture(scope="session", autouse=True) +def clean_context() -> t.Generator[contextlib.ExitStack, None, None]: + stack = contextlib.ExitStack() + yield stack + stack.close() + + +@pytest.fixture(scope="module") +def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]: + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(params=["container", "local"], scope="session") +def deployment_mode(request: pytest.FixtureRequest) -> str: + return request.param + + +@pytest.fixture(scope="module") +def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]): + if deployment_mode == "container": + return functools.partial(_container_handle, deployment_mode=deployment_mode) + elif deployment_mode == "local": + return functools.partial(_local_handle, deployment_mode=deployment_mode) else: - metafunc.parametrize("prompt,llm", [(_PROMPT_MAPPING["default"], llm)]) + raise ValueError(f"Unknown deployment mode: {deployment_mode}") diff --git a/tests/models/flan_t5/__init__.py b/tests/models/flan_t5/__init__.py deleted file mode 100644 index 3a2faba5..00000000 --- a/tests/models/flan_t5/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/models/flan_t5/test_modeling_flan_t5.py b/tests/models/flan_t5/test_modeling_flan_t5.py deleted file mode 100644 index ac277fcd..00000000 --- a/tests/models/flan_t5/test_modeling_flan_t5.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import typing as t - -import openllm - - -def test_small_flan(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool): - assert llm(prompt) - - -def test_small_runner_flan(prompt: str, llm: openllm.LLMRunner, qa: bool): - assert llm(prompt) diff --git a/tests/models/flan_t5/test_modeling_tf_flan_t5.py b/tests/models/flan_t5/test_modeling_tf_flan_t5.py deleted file mode 100644 index 0fb2953c..00000000 --- a/tests/models/flan_t5/test_modeling_tf_flan_t5.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import typing as t - -import openllm - - -@openllm.tests.require_tf -def test_small_tf_flan(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool): - assert llm(prompt) - - -@openllm.tests.require_tf -def test_small_tf_runner_flan(prompt: str, llm: openllm.LLMRunner, qa: bool): - assert llm(prompt) diff --git a/tests/models/flan_t5_test.py b/tests/models/flan_t5_test.py new file mode 100644 index 00000000..4c229dd0 --- /dev/null +++ b/tests/models/flan_t5_test.py @@ -0,0 +1,60 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +import typing as t + +import pytest + +import openllm + + +if t.TYPE_CHECKING: + import contextlib + + from .conftest import HandleProtocol + from .conftest import ResponseComparator + from .conftest import _Handle + + +model = "flan_t5" +model_id = "google/flan-t5-small" + + +@pytest.fixture(scope="module") +def flan_t5_handle( + handler: HandleProtocol, + deployment_mode: t.Literal["container", "local"], + clean_context: contextlib.ExitStack, +): + with openllm.testing.prepare( + model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context + ) as image_tag: + with handler(model=model, model_id=model_id, image_tag=image_tag) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flan_t5(flan_t5_handle: _Handle): + await flan_t5_handle.health(240) + return flan_t5_handle.client + + +@pytest.mark.asyncio() +async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator): + client = await flan_t5 + response = await client.query("What is the meaning of life?", max_new_tokens=10, top_p=0.9, return_attrs=True) + + assert response.configuration["generation_config"]["max_new_tokens"] == 10 + assert response == response_snapshot diff --git a/tests/models/opt/__init__.py b/tests/models/opt/__init__.py deleted file mode 100644 index 3a2faba5..00000000 --- a/tests/models/opt/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/models/opt/test_modeling_flax_opt.py b/tests/models/opt/test_modeling_flax_opt.py deleted file mode 100644 index 70999357..00000000 --- a/tests/models/opt/test_modeling_flax_opt.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import typing as t - -import openllm - - -@openllm.tests.require_flax -def test_small_opt(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool): - assert llm(prompt) - - -@openllm.tests.require_flax -def test_small_runner_opt(prompt: str, llm: openllm.LLMRunner, qa: bool): - assert llm(prompt) diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py deleted file mode 100644 index 1a1d54fa..00000000 --- a/tests/models/opt/test_modeling_opt.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import typing as t - -import openllm - - -def test_small_opt(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool): - assert llm(prompt) - - -def test_small_runner_opt(prompt: str, llm: openllm.LLMRunner, qa: bool): - assert llm(prompt) diff --git a/tests/models/opt/test_modeling_tf_opt.py b/tests/models/opt/test_modeling_tf_opt.py deleted file mode 100644 index cbbe84bb..00000000 --- a/tests/models/opt/test_modeling_tf_opt.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2023 BentoML Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import typing as t - -import openllm - - -@openllm.tests.require_tf -def test_small_opt(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool): - assert llm(prompt) - - -@openllm.tests.require_tf -def test_small_runner_opt(prompt: str, llm: openllm.LLMRunner, qa: bool): - assert llm(prompt) diff --git a/tests/models/opt_test.py b/tests/models/opt_test.py new file mode 100644 index 00000000..b2cc7401 --- /dev/null +++ b/tests/models/opt_test.py @@ -0,0 +1,59 @@ +# Copyright 2023 BentoML Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +import typing as t + +import pytest + +import openllm + + +if t.TYPE_CHECKING: + import contextlib + + from .conftest import HandleProtocol + from .conftest import ResponseComparator + from .conftest import _Handle + + +model = "opt" +model_id = "facebook/opt-125m" + + +@pytest.fixture(scope="module") +def opt_125m_handle( + handler: HandleProtocol, + deployment_mode: t.Literal["container", "local"], + clean_context: contextlib.ExitStack, +): + with openllm.testing.prepare( + model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context + ) as image_tag: + with handler(model=model, model_id=model_id, image_tag=image_tag) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def opt_125m(opt_125m_handle: _Handle): + await opt_125m_handle.health(240) + return opt_125m_handle.client + + +@pytest.mark.asyncio() +async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator): + client = await opt_125m + response = await client.query("What is Deep learning?", max_new_tokens=20, return_attrs=True) + + assert response.configuration["generation_config"]["max_new_tokens"] == 20 + assert response == response_snapshot diff --git a/tests/models/flan_t5/test_modeling_flax_flan_t5.py b/tests/models_test.py similarity index 70% rename from tests/models/flan_t5/test_modeling_flax_flan_t5.py rename to tests/models_test.py index bc8409f8..e1f2e462 100644 --- a/tests/models/flan_t5/test_modeling_flax_flan_t5.py +++ b/tests/models_test.py @@ -13,17 +13,20 @@ # limitations under the License. from __future__ import annotations - import typing as t -import openllm + +if t.TYPE_CHECKING: + import openllm -@openllm.tests.require_flax -def test_small_flax_flan(prompt: str, llm: openllm.LLM[t.Any, t.Any], qa: bool): +def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): assert llm(prompt) + assert llm(prompt, temperature=0.8, top_p=0.23) -@openllm.tests.require_flax -def test_small_flax_runner_flan(prompt: str, llm: openllm.LLMRunner, qa: bool): + +def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): assert llm(prompt) + + assert llm(prompt, temperature=0.9, top_k=8) diff --git a/tests/test_package.py b/tests/package_test.py similarity index 86% rename from tests/test_package.py rename to tests/package_test.py index d61d2581..602c69c6 100644 --- a/tests/test_package.py +++ b/tests/package_test.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t import pytest @@ -60,19 +59,11 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory): assert len(bento_store.list(bento.tag)) == 1 -@pytest.fixture(name="dockerfile_template", scope="function") +@pytest.fixture(name="dockerfile_template") def fixture_dockerfile_template(tmp_path_factory: pytest.TempPathFactory): file = tmp_path_factory.mktemp("dockerfiles") / "Dockerfile.template" file.write_text( - "\n".join( - [ - "{% extends bento_base_template %}", - "{% block SETUP_BENTO_ENTRYPOINT %}", - "{{ super() }}", - "RUN echo 'sanity from custom dockerfile'", - "{% endblock %}", - ] - ) + "{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}" ) return file @@ -82,6 +73,6 @@ def test_build_with_custom_dockerfile(dockerfile_template: Path): assert openllm.build( "flan-t5", model_id=HF_INTERNAL_T5_TESTING, - overwrite_existing_bento=True, + overwrite=True, dockerfile_template=str(dockerfile_template), ) diff --git a/tests/test_strategies.py b/tests/strategies_test.py similarity index 98% rename from tests/test_strategies.py rename to tests/strategies_test.py index 65201308..8ae288a0 100644 --- a/tests/test_strategies.py +++ b/tests/strategies_test.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations - import typing as t import pytest @@ -47,7 +46,7 @@ def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str): GPURunnable, {gpu_type: 0}, 1, - ) + ).match("No known supported resource available for *") assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 2 assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 2) == 4 diff --git a/tools/assert-model-table-latest b/tools/assert-model-table-latest index 0bd78414..af19d26e 100755 --- a/tools/assert-model-table-latest +++ b/tools/assert-model-table-latest @@ -18,7 +18,7 @@ with open(os.path.join(ROOT, "README.md"), "r") as f: # NOTE: Currently, we only have one table in README, which is the Model readme. table = [r for r in readme if r.type == "html_block" and r.content.startswith(" None: + """Embed and start an IPython kernel in a given scope. + + If you don't want the kernel to initialize the namespace + from the scope of the surrounding function, + and/or you want to load full IPython configuration, + you probably want `IPython.start_kernel()` instead. + + Parameters + ---------- + module : types.ModuleType, optional + The module to load into IPython globals (default: caller) + local_ns : dict, optional + The namespace to load into IPython user namespace (default: caller) + **kwargs : various, optional + Further keyword args are relayed to the IPKernelApp constructor, + such as `config`, a traitlets :class:`Config` object (see :ref:`configure_start_ipython`), + allowing configuration of the kernel (see :ref:`kernel_options`). Will only have an effect + on the first embed_kernel call for a given process. + """ + ... + +def start_ipython(argv=..., **kwargs): # -> None: + """Launch a normal IPython instance (as opposed to embedded) + + `IPython.embed()` puts a shell in a particular calling scope, + such as a function or method for debugging purposes, + which is often not desirable. + + `start_ipython()` does full, regular IPython initialization, + including loading startup files, configuration, etc. + much of which is skipped by `embed()`. + + This is a public API method, and will survive implementation changes. + + Parameters + ---------- + argv : list or None, optional + If unspecified or None, IPython will parse command-line options from sys.argv. + To prevent any command-line parsing, pass an empty list: `argv=[]`. + user_ns : dict, optional + specify this dictionary to initialize the IPython user namespace with particular values. + **kwargs : various, optional + Any other kwargs will be passed to the Application constructor, + such as `config`, a traitlets :class:`Config` object (see :ref:`configure_start_ipython`), + allowing configuration of the instance (see :ref:`terminal_options`). + """ + ... + +def start_kernel(argv=..., **kwargs): # -> None: + """Launch a normal IPython kernel instance (as opposed to embedded) + + `IPython.embed_kernel()` puts a shell in a particular calling scope, + such as a function or method for debugging purposes, + which is often not desirable. + + `start_kernel()` does full, regular IPython initialization, + including loading startup files, configuration, etc. + much of which is skipped by `embed_kernel()`. + + Parameters + ---------- + argv : list or None, optional + If unspecified or None, IPython will parse command-line options from sys.argv. + To prevent any command-line parsing, pass an empty list: `argv=[]`. + user_ns : dict, optional + specify this dictionary to initialize the IPython user namespace with particular values. + **kwargs : various, optional + Any other kwargs will be passed to the Application constructor, + such as `config`, a traitlets :class:`Config` object (see :ref:`configure_start_ipython`), + allowing configuration of the kernel (see :ref:`kernel_options`). + """ + ... diff --git a/typings/IPython/core/__init__.pyi b/typings/IPython/core/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/typings/IPython/core/getipython.pyi b/typings/IPython/core/getipython.pyi new file mode 100644 index 00000000..8d3436df --- /dev/null +++ b/typings/IPython/core/getipython.pyi @@ -0,0 +1,5 @@ +"""Simple function to call to get the current InteractiveShell instance +""" +from IPython.terminal.interactiveshell import InteractiveShell + +def get_ipython() -> None | InteractiveShell: ... diff --git a/typings/IPython/terminal/__init__.pyi b/typings/IPython/terminal/__init__.pyi new file mode 100644 index 00000000..cea7ef96 --- /dev/null +++ b/typings/IPython/terminal/__init__.pyi @@ -0,0 +1,3 @@ +""" +This type stub file was generated by pyright. +""" diff --git a/typings/IPython/terminal/debugger.pyi b/typings/IPython/terminal/debugger.pyi new file mode 100644 index 00000000..f5ede65a --- /dev/null +++ b/typings/IPython/terminal/debugger.pyi @@ -0,0 +1,41 @@ +""" +This type stub file was generated by pyright. +""" + +from IPython.core.debugger import Pdb + +PTK3 = ... +_use_simple_prompt = ... + +class TerminalPdb(Pdb): + """Standalone IPython debugger.""" + + def __init__(self, *args, pt_session_options=..., **kwargs) -> None: ... + def pt_init(self, pt_session_options=...): # -> None: + """Initialize the prompt session and the prompt loop + and store them in self.pt_app and self.pt_loop. + + Additional keyword arguments for the PromptSession class + can be specified in pt_session_options. + """ + ... + def cmdloop(self, intro=...): # -> None: + """Repeatedly issue a prompt, accept input, parse an initial prefix + off the received input, and dispatch to action methods, passing them + the remainder of the line as argument. + + override the same methods from cmd.Cmd to provide prompt toolkit replacement. + """ + ... + def do_interact(self, arg): ... + +def set_trace(frame=...): # -> None: + """ + Start debugging from `frame`. + + If frame is not specified, debugging starts from caller's frame. + """ + ... + +if __name__ == "__main__": + old_trace_dispatch = ... diff --git a/typings/IPython/terminal/embed.pyi b/typings/IPython/terminal/embed.pyi new file mode 100644 index 00000000..cff3af4c --- /dev/null +++ b/typings/IPython/terminal/embed.pyi @@ -0,0 +1,165 @@ +""" +This type stub file was generated by pyright. +""" + +from typing import Set + +from IPython.core import magic_arguments +from IPython.core.magic import Magics +from IPython.core.magic import line_magic +from IPython.core.magic import magics_class +from IPython.terminal.interactiveshell import TerminalInteractiveShell + +""" +An embedded IPython shell. +""" + +class KillEmbedded(Exception): ... + +KillEmbeded = KillEmbedded + +@magics_class +class EmbeddedMagics(Magics): + @line_magic + @magic_arguments.magic_arguments() + @magic_arguments.argument("-i", "--instance", action="store_true", help="Kill instance instead of call location") + @magic_arguments.argument("-x", "--exit", action="store_true", help="Also exit the current session") + @magic_arguments.argument("-y", "--yes", action="store_true", help="Do not ask confirmation") + def kill_embedded(self, parameter_s=...): # -> None: + """%kill_embedded : deactivate for good the current embedded IPython + + This function (after asking for confirmation) sets an internal flag so + that an embedded IPython will never activate again for the given call + location. This is useful to permanently disable a shell that is being + called inside a loop: once you've figured out what you needed from it, + you may then kill it and the program will then continue to run without + the interactive shell interfering again. + + Kill Instance Option: + + If for some reasons you need to kill the location where the instance + is created and not called, for example if you create a single + instance in one place and debug in many locations, you can use the + ``--instance`` option to kill this specific instance. Like for the + ``call location`` killing an "instance" should work even if it is + recreated within a loop. + + .. note:: + + This was the default behavior before IPython 5.2 + + """ + ... + @line_magic + def exit_raise(self, parameter_s=...): # -> None: + """%exit_raise Make the current embedded kernel exit and raise and exception. + + This function sets an internal flag so that an embedded IPython will + raise a `IPython.terminal.embed.KillEmbedded` Exception on exit, and then exit the current I. This is + useful to permanently exit a loop that create IPython embed instance. + """ + ... + +class _Sentinel: + def __init__(self, repr) -> None: ... + def __repr__(self): ... + +class InteractiveShellEmbed(TerminalInteractiveShell): + dummy_mode = ... + exit_msg = ... + embedded = ... + should_raise = ... + display_banner = ... + exit_msg = ... + term_title = ... + _inactive_locations: Set[str] = ... + @property + def embedded_active(self): ... + @embedded_active.setter + def embedded_active(self, value): ... + def __init__(self, **kw) -> None: ... + def init_sys_modules(self): # -> None: + """ + Explicitly overwrite :mod:`IPython.core.interactiveshell` to do nothing. + """ + ... + def init_magics(self): ... + def __call__( + self, header=..., local_ns=..., module=..., dummy=..., stack_depth=..., compile_flags=..., **kw + ): # -> None: + """Activate the interactive interpreter. + + __call__(self,header='',local_ns=None,module=None,dummy=None) -> Start + the interpreter shell with the given local and global namespaces, and + optionally print a header string at startup. + + The shell can be globally activated/deactivated using the + dummy_mode attribute. This allows you to turn off a shell used + for debugging globally. + + However, *each* time you call the shell you can override the current + state of dummy_mode with the optional keyword parameter 'dummy'. For + example, if you set dummy mode on with IPShell.dummy_mode = True, you + can still have a specific call work by making it as IPShell(dummy=False). + """ + ... + def mainloop(self, local_ns=..., module=..., stack_depth=..., compile_flags=...): # -> None: + """Embeds IPython into a running python program. + + Parameters + ---------- + local_ns, module + Working local namespace (a dict) and module (a module or similar + object). If given as None, they are automatically taken from the scope + where the shell was called, so that program variables become visible. + stack_depth : int + How many levels in the stack to go to looking for namespaces (when + local_ns or module is None). This allows an intermediate caller to + make sure that this function gets the namespace from the intended + level in the stack. By default (0) it will get its locals and globals + from the immediate caller. + compile_flags + A bit field identifying the __future__ features + that are enabled, as passed to the builtin :func:`compile` function. + If given as None, they are automatically taken from the scope where + the shell was called. + + """ + ... + +def embed(*, header=..., compile_flags=..., **kwargs): # -> None: + """Call this to embed IPython at the current point in your program. + + The first invocation of this will create a :class:`terminal.embed.InteractiveShellEmbed` + instance and then call it. Consecutive calls just call the already + created instance. + + If you don't want the kernel to initialize the namespace + from the scope of the surrounding function, + and/or you want to load full IPython configuration, + you probably want `IPython.start_ipython()` instead. + + Here is a simple example:: + + from IPython import embed + a = 10 + b = 20 + embed(header='First time') + c = 30 + d = 40 + embed() + + Parameters + ---------- + + header : str + Optional header string to print at startup. + compile_flags + Passed to the `compile_flags` parameter of :py:meth:`terminal.embed.InteractiveShellEmbed.mainloop()`, + which is called when the :class:`terminal.embed.InteractiveShellEmbed` instance is called. + **kwargs : various, optional + Any other kwargs will be passed to the :class:`terminal.embed.InteractiveShellEmbed` constructor. + Full customization can be done by passing a traitlets :class:`Config` in as the + `config` argument (see :ref:`configure_start_ipython` and :ref:`terminal_options`). + """ + ... diff --git a/typings/IPython/terminal/interactiveshell.pyi b/typings/IPython/terminal/interactiveshell.pyi new file mode 100644 index 00000000..e85c6856 --- /dev/null +++ b/typings/IPython/terminal/interactiveshell.pyi @@ -0,0 +1,125 @@ +""" +This type stub file was generated by pyright. +""" + +from typing import Union as UnionType + +from IPython.core.interactiveshell import InteractiveShell +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.history import History +from prompt_toolkit.shortcuts import PromptSession +from pygments.style import Style +from traitlets import Integer +from traitlets import observe + +from .shortcuts.auto_suggest import NavigableAutoSuggestFromHistory + +"""IPython terminal interface using prompt_toolkit""" +PTK3 = ... + +class _NoStyle(Style): ... + +_style_overrides_light_bg = ... +_style_overrides_linux = ... + +def get_default_editor(): ... + +_use_simple_prompt = ... + +def black_reformat_handler(text_before_cursor): # -> str: + """ + We do not need to protect against error, + this is taken care at a higher level where any reformat error is ignored. + Indeed we may call reformatting on incomplete code. + """ + ... + +def yapf_reformat_handler(text_before_cursor): ... + +class PtkHistoryAdapter(History): + """ + Prompt toolkit has it's own way of handling history, Where it assumes it can + Push/pull from history. + + """ + + def __init__(self, shell) -> None: ... + def append_string(self, string): ... + def load_history_strings(self): ... + def store_string(self, string: str) -> None: ... + +class TerminalInteractiveShell(InteractiveShell): + mime_renderers = ... + space_for_menu = Integer(6, help=...).tag(config=True) + pt_app: UnionType[PromptSession, None] = ... + auto_suggest: UnionType[AutoSuggestFromHistory, NavigableAutoSuggestFromHistory, None] = ... + debugger_history = ... + debugger_history_file = ... + simple_prompt = ... + @property + def debugger_cls(self): ... + + confirm_exit = ... + editing_mode = ... + emacs_bindings_in_vi_insert_mode = ... + modal_cursor = ... + ttimeoutlen = ... + timeoutlen = ... + autoformatter = ... + auto_match = ... + mouse_support = ... + highlighting_style = ... + def refresh_style(self): ... + + highlighting_style_overrides = ... + true_color = ... + editor = ... + prompts_class = ... + prompts = ... + term_title = ... + term_title_format = ... + display_completions = ... + highlight_matching_brackets = ... + extra_open_editor_shortcuts = ... + handle_return = ... + enable_history_search = ... + autosuggestions_provider = ... + shortcuts = ... + prompt_includes_vi_mode = ... + @observe("term_title") + def init_term_title(self, change=...): ... + def restore_term_title(self): ... + def init_display_formatter(self): ... + def init_prompt_toolkit_cli(self): ... + @property + def pt_complete_style(self): ... + @property + def color_depth(self): ... + def prompt_for_code(self): ... + def enable_win_unicode_console(self): ... + def init_io(self): ... + def init_magics(self): ... + def init_alias(self): ... + def __init__(self, *args, **kwargs) -> None: ... + def ask_exit(self): ... + + rl_next_input = ... + def interact(self): ... + def mainloop(self): ... + + _inputhook = ... + def inputhook(self, context): ... + + active_eventloop = ... + def enable_gui(self, gui=...): ... + + system = ... + def auto_rewrite_input(self, cmd): # -> None: + """Overridden from the parent class to use fancy rewriting prompt""" + ... + _prompts_before = ... + def switch_doctest_mode(self, mode): # -> None: + """Switch prompts to classic for %doctest_mode""" + ... + +if __name__ == "__main__": ... diff --git a/typings/IPython/terminal/ipapp.pyi b/typings/IPython/terminal/ipapp.pyi new file mode 100644 index 00000000..2ab005d3 --- /dev/null +++ b/typings/IPython/terminal/ipapp.pyi @@ -0,0 +1,70 @@ +""" +This type stub file was generated by pyright. +""" + +from IPython.core.application import BaseIPythonApplication +from IPython.core.crashhandler import CrashHandler +from IPython.core.shellapp import InteractiveShellApp +from traitlets.config.application import catch_config_error + +""" +The :class:`~traitlets.config.application.Application` object for the command +line :command:`ipython` program. +""" +_examples = ... + +class IPAppCrashHandler(CrashHandler): + """sys.excepthook for IPython itself, leaves a detailed report on disk.""" + + def __init__(self, app) -> None: ... + def make_report(self, traceback): # -> str: + """Return a string containing a crash report.""" + ... + +flags = ... +frontend_flags = ... +addflag = ... +classic_config = ... +aliases = ... + +class LocateIPythonApp(BaseIPythonApplication): + description = ... + subcommands = ... + def start(self): ... + +class TerminalIPythonApp(BaseIPythonApplication, InteractiveShellApp): + name = ... + description = ... + crash_handler_class = IPAppCrashHandler + examples = ... + flags = ... + aliases = ... + classes = ... + interactive_shell_class = ... + subcommands = ... + auto_create = ... + quick = ... + display_banner = ... + force_interact = ... + something_to_run = ... + @catch_config_error + def initialize(self, argv=...): # -> None: + """Do actions after construct, but before starting the app.""" + ... + def init_shell(self): # -> None: + """initialize the InteractiveShell instance""" + ... + def init_banner(self): # -> None: + """optionally display the banner""" + ... + def start(self): ... + +def load_default_config(ipython_dir=...): # -> Instance | Any: + """Load the default config file from the default ipython_dir. + + This is useful for embedded shells. + """ + ... + +launch_new_instance = ... +if __name__ == "__main__": ... diff --git a/typings/IPython/terminal/magics.pyi b/typings/IPython/terminal/magics.pyi new file mode 100644 index 00000000..f651b5ce --- /dev/null +++ b/typings/IPython/terminal/magics.pyi @@ -0,0 +1,119 @@ +""" +This type stub file was generated by pyright. +""" + +import sys + +from IPython.core.magic import Magics +from IPython.core.magic import line_magic +from IPython.core.magic import magics_class +from IPython.testing.skipdoctest import skip_doctest + +"""Extra magics for terminal use.""" + +def get_pasted_lines(sentinel, l_input=..., quiet=...): # -> Generator[Unknown, Any, None]: + """Yield pasted lines until the user enters the given sentinel value.""" + ... + +@magics_class +class TerminalMagics(Magics): + def __init__(self, shell) -> None: ... + def store_or_execute(self, block, name, store_history=...): # -> None: + """Execute a block, or store it in a variable, per the user's request.""" + ... + def preclean_input(self, block): ... + def rerun_pasted(self, name=...): # -> None: + """Rerun a previously pasted command.""" + ... + @line_magic + def autoindent(self, parameter_s=...): # -> None: + """Toggle autoindent on/off (deprecated)""" + ... + @skip_doctest + @line_magic + def cpaste(self, parameter_s=...): # -> None: + """Paste & execute a pre-formatted code block from clipboard. + + You must terminate the block with '--' (two minus-signs) or Ctrl-D + alone on the line. You can also provide your own sentinel with '%paste + -s %%' ('%%' is the new sentinel for this operation). + + The block is dedented prior to execution to enable execution of method + definitions. '>' and '+' characters at the beginning of a line are + ignored, to allow pasting directly from e-mails, diff files and + doctests (the '...' continuation prompt is also stripped). The + executed block is also assigned to variable named 'pasted_block' for + later editing with '%edit pasted_block'. + + You can also pass a variable name as an argument, e.g. '%cpaste foo'. + This assigns the pasted block to variable 'foo' as string, without + dedenting or executing it (preceding >>> and + is still stripped) + + '%cpaste -r' re-executes the block previously entered by cpaste. + '%cpaste -q' suppresses any additional output messages. + + Do not be alarmed by garbled output on Windows (it's a readline bug). + Just press enter and type -- (and press enter again) and the block + will be what was just pasted. + + Shell escapes are not supported (yet). + + See Also + -------- + paste : automatically pull code from clipboard. + + Examples + -------- + :: + + In [8]: %cpaste + Pasting code; enter '--' alone on the line to stop. + :>>> a = ["world!", "Hello"] + :>>> print(" ".join(sorted(a))) + :-- + Hello world! + + :: + In [8]: %cpaste + Pasting code; enter '--' alone on the line to stop. + :>>> %alias_magic t timeit + :>>> %t -n1 pass + :-- + Created `%t` as an alias for `%timeit`. + Created `%%t` as an alias for `%%timeit`. + 354 ns ± 224 ns per loop (mean ± std. dev. of 7 runs, 1 loop each) + """ + ... + @line_magic + def paste(self, parameter_s=...): # -> None: + """Paste & execute a pre-formatted code block from clipboard. + + The text is pulled directly from the clipboard without user + intervention and printed back on the screen before execution (unless + the -q flag is given to force quiet mode). + + The block is dedented prior to execution to enable execution of method + definitions. '>' and '+' characters at the beginning of a line are + ignored, to allow pasting directly from e-mails, diff files and + doctests (the '...' continuation prompt is also stripped). The + executed block is also assigned to variable named 'pasted_block' for + later editing with '%edit pasted_block'. + + You can also pass a variable name as an argument, e.g. '%paste foo'. + This assigns the pasted block to variable 'foo' as string, without + executing it (preceding >>> and + is still stripped). + + Options: + + -r: re-executes the block previously entered by cpaste. + + -q: quiet mode: do not echo the pasted text back to the terminal. + + IPython statements (magics, shell escapes) are not supported (yet). + + See Also + -------- + cpaste : manually paste code into terminal until you mark its end. + """ + ... + if sys.platform == "win32": ... diff --git a/typings/IPython/terminal/prompts.pyi b/typings/IPython/terminal/prompts.pyi new file mode 100644 index 00000000..acc02295 --- /dev/null +++ b/typings/IPython/terminal/prompts.pyi @@ -0,0 +1,27 @@ +""" +This type stub file was generated by pyright. +""" + +from IPython.core.displayhook import DisplayHook + +"""Terminal input and output prompts.""" + +class Prompts: + def __init__(self, shell) -> None: ... + def vi_mode(self): ... + def in_prompt_tokens(self): ... + def continuation_prompt_tokens(self, width=...): ... + def rewrite_prompt_tokens(self): ... + def out_prompt_tokens(self): ... + +class ClassicPrompts(Prompts): + def in_prompt_tokens(self): ... + def continuation_prompt_tokens(self, width=...): ... + def rewrite_prompt_tokens(self): ... + def out_prompt_tokens(self): ... + +class RichPromptDisplayHook(DisplayHook): + """Subclass of base display hook using coloured prompt""" + + def write_output_prompt(self): ... + def write_format_data(self, format_dict, md_dict=...) -> None: ... diff --git a/typings/IPython/terminal/pt_inputhooks/__init__.pyi b/typings/IPython/terminal/pt_inputhooks/__init__.pyi new file mode 100644 index 00000000..bb03e340 --- /dev/null +++ b/typings/IPython/terminal/pt_inputhooks/__init__.pyi @@ -0,0 +1,23 @@ +""" +This type stub file was generated by pyright. +""" + +import importlib +import os + +aliases = ... +backends = ... +registered = ... + +def register(name, inputhook): # -> None: + """Register the function *inputhook* as an event loop integration.""" + ... + +class UnknownBackend(KeyError): + def __init__(self, name) -> None: ... + +def set_qt_api(gui): # -> str | None: + """Sets the `QT_API` environment variable if it isn't already set.""" + ... + +def get_inputhook_name_and_func(gui): ... diff --git a/typings/IPython/terminal/ptutils.pyi b/typings/IPython/terminal/ptutils.pyi new file mode 100644 index 00000000..362970af --- /dev/null +++ b/typings/IPython/terminal/ptutils.pyi @@ -0,0 +1,29 @@ +""" +This type stub file was generated by pyright. +""" + +from prompt_toolkit.completion import Completer +from prompt_toolkit.lexers import Lexer + +"""prompt-toolkit utilities + +Everything in this module is a private API, +not to be used outside IPython. +""" +_completion_sentinel = ... + +class IPythonPTCompleter(Completer): + """Adaptor to provide IPython completions to prompt_toolkit""" + + def __init__(self, ipy_completer=..., shell=...) -> None: ... + @property + def ipy_completer(self): ... + def get_completions(self, document, complete_event): ... + +class IPythonPTLexer(Lexer): + """ + Wrapper around PythonLexer and BashLexer. + """ + + def __init__(self) -> None: ... + def lex_document(self, document): ... diff --git a/typings/IPython/terminal/shortcuts/__init__.pyi b/typings/IPython/terminal/shortcuts/__init__.pyi new file mode 100644 index 00000000..81248455 --- /dev/null +++ b/typings/IPython/terminal/shortcuts/__init__.pyi @@ -0,0 +1,149 @@ +""" +This type stub file was generated by pyright. +""" + +import os +import signal +import sys +import warnings +from dataclasses import dataclass +from typing import Any +from typing import Callable +from typing import List +from typing import Optional + +from IPython.core.getipython import get_ipython +from IPython.terminal.shortcuts import auto_match as match +from IPython.terminal.shortcuts import auto_suggest +from IPython.terminal.shortcuts.filters import filter_from_string +from IPython.utils.decorators import undoc +from prompt_toolkit.application.current import get_app +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.key_binding.bindings import named_commands as nc +from prompt_toolkit.key_binding.bindings.completion import display_completions_like_readline +from prompt_toolkit.key_binding.key_processor import KeyPressEvent +from prompt_toolkit.key_binding.vi_state import InputMode +from prompt_toolkit.key_binding.vi_state import ViState + +""" +Module to define and register Terminal IPython shortcuts with +:mod:`prompt_toolkit` +""" +__all__ = ["create_ipython_shortcuts"] + +@dataclass +class BaseBinding: + command: Callable[[KeyPressEvent], Any] + keys: List[str] + +@dataclass +class RuntimeBinding(BaseBinding): + filter: Condition + +@dataclass +class Binding(BaseBinding): + condition: Optional[str] = ... + def __post_init__(self): ... + +def create_identifier(handler: Callable): ... + +AUTO_MATCH_BINDINGS = ... +AUTO_SUGGEST_BINDINGS = ... +SIMPLE_CONTROL_BINDINGS = ... +ALT_AND_COMOBO_CONTROL_BINDINGS = ... + +def add_binding(bindings: KeyBindings, binding: Binding): ... +def create_ipython_shortcuts(shell, skip=...) -> KeyBindings: + """Set up the prompt_toolkit keyboard shortcuts for IPython. + + Parameters + ---------- + shell: InteractiveShell + The current IPython shell Instance + skip: List[Binding] + Bindings to skip. + + Returns + ------- + KeyBindings + the keybinding instance for prompt toolkit. + + """ + ... + +def reformat_and_execute(event): # -> None: + """Reformat code and execute it""" + ... + +def reformat_text_before_cursor(buffer, document, shell): ... +def handle_return_or_newline_or_execute(event): ... +def newline_or_execute_outer(shell): ... +def previous_history_or_previous_completion(event): # -> None: + """ + Control-P in vi edit mode on readline is history next, unlike default prompt toolkit. + + If completer is open this still select previous completion. + """ + ... + +def next_history_or_next_completion(event): # -> None: + """ + Control-N in vi edit mode on readline is history previous, unlike default prompt toolkit. + + If completer is open this still select next completion. + """ + ... + +def dismiss_completion(event): # -> None: + """Dismiss completion""" + ... + +def reset_buffer(event): # -> None: + """Reset buffer""" + ... + +def reset_search_buffer(event): # -> None: + """Reset search buffer""" + ... + +def suspend_to_bg(event): # -> None: + """Suspend to background""" + ... + +def quit(event): # -> None: + """ + Quit application with ``SIGQUIT`` if supported or ``sys.exit`` otherwise. + + On platforms that support SIGQUIT, send SIGQUIT to the current process. + On other platforms, just exit the process with a message. + """ + ... + +def indent_buffer(event): # -> None: + """Indent buffer""" + ... + +def newline_autoindent(event): # -> None: + """Insert a newline after the cursor indented appropriately. + + Fancier version of former ``newline_with_copy_margin`` which should + compute the correct indentation of the inserted line. That is to say, indent + by 4 extra space after a function definition, class definition, context + manager... And dedent by 4 space after ``pass``, ``return``, ``raise ...``. + """ + ... + +def open_input_in_editor(event): # -> None: + """Open code from input in external editor""" + ... + +if sys.platform == "win32": ... +else: + @undoc + def win_paste(event): # -> None: + """Stub used on other platforms""" + ... + +KEY_BINDINGS = ... diff --git a/typings/IPython/terminal/shortcuts/auto_match.pyi b/typings/IPython/terminal/shortcuts/auto_match.pyi new file mode 100644 index 00000000..2cf83b33 --- /dev/null +++ b/typings/IPython/terminal/shortcuts/auto_match.pyi @@ -0,0 +1,65 @@ +""" +This type stub file was generated by pyright. +""" + +from prompt_toolkit.key_binding import KeyPressEvent + +""" +Utilities function for keybinding with prompt toolkit. + +This will be bound to specific key press and filter modes, +like whether we are in edit mode, and whether the completer is open. +""" + +def parenthesis(event: KeyPressEvent): # -> None: + """Auto-close parenthesis""" + ... + +def brackets(event: KeyPressEvent): # -> None: + """Auto-close brackets""" + ... + +def braces(event: KeyPressEvent): # -> None: + """Auto-close braces""" + ... + +def double_quote(event: KeyPressEvent): # -> None: + """Auto-close double quotes""" + ... + +def single_quote(event: KeyPressEvent): # -> None: + """Auto-close single quotes""" + ... + +def docstring_double_quotes(event: KeyPressEvent): # -> None: + """Auto-close docstring (double quotes)""" + ... + +def docstring_single_quotes(event: KeyPressEvent): # -> None: + """Auto-close docstring (single quotes)""" + ... + +def raw_string_parenthesis(event: KeyPressEvent): # -> None: + """Auto-close parenthesis in raw strings""" + ... + +def raw_string_bracket(event: KeyPressEvent): # -> None: + """Auto-close bracker in raw strings""" + ... + +def raw_string_braces(event: KeyPressEvent): # -> None: + """Auto-close braces in raw strings""" + ... + +def skip_over(event: KeyPressEvent): # -> None: + """Skip over automatically added parenthesis/quote. + + (rather than adding another parenthesis/quote)""" + ... + +def delete_pair(event: KeyPressEvent): # -> None: + """Delete auto-closed parenthesis""" + ... + +auto_match_parens = ... +auto_match_parens_raw_string = ... diff --git a/typings/IPython/terminal/shortcuts/auto_suggest.pyi b/typings/IPython/terminal/shortcuts/auto_suggest.pyi new file mode 100644 index 00000000..8adc0d54 --- /dev/null +++ b/typings/IPython/terminal/shortcuts/auto_suggest.pyi @@ -0,0 +1,101 @@ +""" +This type stub file was generated by pyright. +""" + +from typing import Optional +from typing import Union + +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.auto_suggest import Suggestion +from prompt_toolkit.buffer import Buffer +from prompt_toolkit.document import Document +from prompt_toolkit.history import History +from prompt_toolkit.key_binding import KeyPressEvent +from prompt_toolkit.layout.processors import Processor +from prompt_toolkit.layout.processors import Transformation +from prompt_toolkit.layout.processors import TransformationInput +from prompt_toolkit.shortcuts import PromptSession + +class AppendAutoSuggestionInAnyLine(Processor): + """ + Append the auto suggestion to lines other than the last (appending to the + last line is natively supported by the prompt toolkit). + """ + + def __init__(self, style: str = ...) -> None: ... + def apply_transformation(self, ti: TransformationInput) -> Transformation: ... + +class NavigableAutoSuggestFromHistory(AutoSuggestFromHistory): + """ + A subclass of AutoSuggestFromHistory that allow navigation to next/previous + suggestion from history. To do so it remembers the current position, but it + state need to carefully be cleared on the right events. + """ + + def __init__(self) -> None: ... + def reset_history_position(self, _: Buffer): ... + def disconnect(self): ... + def connect(self, pt_app: PromptSession): ... + def get_suggestion(self, buffer: Buffer, document: Document) -> Optional[Suggestion]: ... + def up(self, query: str, other_than: str, history: History) -> None: ... + def down(self, query: str, other_than: str, history: History) -> None: ... + +def accept_or_jump_to_end(event: KeyPressEvent): # -> None: + """Apply autosuggestion or jump to end of line.""" + ... + +def accept(event: KeyPressEvent): # -> None: + """Accept autosuggestion""" + ... + +def discard(event: KeyPressEvent): # -> None: + """Discard autosuggestion""" + ... + +def accept_word(event: KeyPressEvent): # -> None: + """Fill partial autosuggestion by word""" + ... + +def accept_character(event: KeyPressEvent): # -> None: + """Fill partial autosuggestion by character""" + ... + +def accept_and_keep_cursor(event: KeyPressEvent): # -> None: + """Accept autosuggestion and keep cursor in place""" + ... + +def accept_and_move_cursor_left(event: KeyPressEvent): # -> None: + """Accept autosuggestion and move cursor left in place""" + ... + +def backspace_and_resume_hint(event: KeyPressEvent): # -> None: + """Resume autosuggestions after deleting last character""" + ... + +def resume_hinting(event: KeyPressEvent): # -> None: + """Resume autosuggestions""" + ... + +def up_and_update_hint(event: KeyPressEvent): # -> None: + """Go up and update hint""" + ... + +def down_and_update_hint(event: KeyPressEvent): # -> None: + """Go down and update hint""" + ... + +def accept_token(event: KeyPressEvent): # -> None: + """Fill partial autosuggestion by token""" + ... + +Provider = Union[AutoSuggestFromHistory, NavigableAutoSuggestFromHistory, None] + +def swap_autosuggestion_up(event: KeyPressEvent): # -> None: + """Get next autosuggestion from history.""" + ... + +def swap_autosuggestion_down(event: KeyPressEvent): # -> None: + """Get previous autosuggestion from history.""" + ... + +def __getattr__(key): ... diff --git a/typings/IPython/terminal/shortcuts/filters.pyi b/typings/IPython/terminal/shortcuts/filters.pyi new file mode 100644 index 00000000..17c3ed8e --- /dev/null +++ b/typings/IPython/terminal/shortcuts/filters.pyi @@ -0,0 +1,81 @@ +""" +This type stub file was generated by pyright. +""" + +import ast +from typing import Callable +from typing import Dict +from typing import Union + +from IPython.utils.decorators import undoc +from prompt_toolkit.filters import Condition +from prompt_toolkit.filters import Filter +from prompt_toolkit.key_binding import KeyPressEvent +from prompt_toolkit.layout.layout import FocusableElement + +""" +Filters restricting scope of IPython Terminal shortcuts. +""" + +@undoc +@Condition +def cursor_in_leading_ws(): ... +def has_focus(value: FocusableElement): # -> Condition: + """Wrapper around has_focus adding a nice `__name__` to tester function""" + ... + +@undoc +@Condition +def has_line_below() -> bool: ... +@undoc +@Condition +def is_cursor_at_the_end_of_line() -> bool: ... +@undoc +@Condition +def has_line_above() -> bool: ... +@Condition +def ebivim(): ... +@Condition +def supports_suspend(): ... +@Condition +def auto_match(): ... +def all_quotes_paired(quote, buf): ... + +_preceding_text_cache: Dict[Union[str, Callable], Condition] = ... +_following_text_cache: Dict[Union[str, Callable], Condition] = ... + +def preceding_text(pattern: Union[str, Callable]): ... +def following_text(pattern): ... +@Condition +def not_inside_unclosed_string(): ... +@Condition +def navigable_suggestions(): ... +@Condition +def readline_like_completions(): ... +@Condition +def is_windows_os(): ... + +class PassThrough(Filter): + """A filter allowing to implement pass-through behaviour of keybindings. + + Prompt toolkit key processor dispatches only one event per binding match, + which means that adding a new shortcut will suppress the old shortcut + if the keybindings are the same (unless one is filtered out). + + To stop a shortcut binding from suppressing other shortcuts: + - add the `pass_through` filter to list of filter, and + - call `pass_through.reply(event)` in the shortcut handler. + """ + + def __init__(self) -> None: ... + def reply(self, event: KeyPressEvent): ... + def __call__(self): ... + +pass_through = ... +default_buffer_focused = ... +KEYBINDING_FILTERS = ... + +def eval_node(node: Union[ast.AST, None]): ... +def filter_from_string(code: str): ... + +__all__ = ["KEYBINDING_FILTERS", "filter_from_string"] diff --git a/typings/attr/__init__.pyi b/typings/attr/__init__.pyi index fdfbec82..13be27fe 100644 --- a/typings/attr/__init__.pyi +++ b/typings/attr/__init__.pyi @@ -1,7 +1,4 @@ -from __future__ import annotations - import enum -import sys from typing import Any from typing import Callable from typing import Dict @@ -18,6 +15,7 @@ from typing import Type from typing import TypeGuard from typing import TypeVar from typing import Union +from typing import dataclass_transform from typing import overload from . import converters as converters @@ -43,14 +41,14 @@ _T = TypeVar("_T") _C = TypeVar("_C", bound=type) _P = ParamSpec("_P") _EqOrderType = Union[bool, Callable[[Any], Any]] -_ValidatorType = Callable[[Any, "Attribute[_T]", _T], Any] +_ValidatorType = Callable[[Any, Attribute[_T], _T], Any] _ConverterType = Callable[[Any], Any] -_FilterType = Callable[["Attribute[_T]", _T], bool] +_FilterType = Callable[[Attribute[_T], _T], bool] _ReprType = Callable[[Any], str] _ReprArgType = Union[bool, _ReprType] -_OnSetAttrType = Callable[[Any, "Attribute[Any]", Any], Any] +_OnSetAttrType = Callable[[Any, Attribute[Any], Any], Any] _OnSetAttrArgType = Union[_OnSetAttrType, List[_OnSetAttrType], setters._NoOpType] -_FieldTransformer = Callable[[type, List["Attribute[Any]"]], List["Attribute[Any]"]] +_FieldTransformer = Callable[[type, List[Attribute[Any]]], List[Attribute[Any]]] _ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]] class ReprProtocol(Protocol): @@ -64,22 +62,13 @@ class _Nothing(enum.Enum): NOTHING = ... NOTHING = ... -if sys.version_info >= (3, 8): - @overload - def Factory(factory: Callable[[], _T]) -> _T: ... - @overload - def Factory(factory: Callable[[Any], _T], takes_self: Literal[True]) -> _T: ... - @overload - def Factory(factory: Callable[[], _T], takes_self: Literal[False]) -> _T: ... -def __dataclass_transform__( - *, - eq_default: bool = ..., - order_default: bool = ..., - kw_only_default: bool = ..., - frozen_default: bool = ..., - field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = ..., -) -> Callable[[_T], _T]: ... +@overload +def Factory(factory: Callable[[], _T]) -> _T: ... +@overload +def Factory(factory: Callable[[Any], _T], takes_self: Literal[True]) -> _T: ... +@overload +def Factory(factory: Callable[[], _T], takes_self: Literal[False]) -> _T: ... class _CountingAttr(Generic[_T]): counter: int @@ -90,36 +79,60 @@ class _CountingAttr(Generic[_T]): eq_key: str order: _EqOrderType order_key: str - hash: Optional[bool] + hash: bool | None init: bool - converter: Optional[_ConverterType] - metadata: Dict[Any, Any] - _validator: Optional[_ValidatorType[_T]] - type: Optional[Type[_T]] + converter: _ConverterType | None + metadata: dict[Any, Any] + _validator: _ValidatorType[_T] | None + type: type[_T] | None kw_only: bool on_setattr: _OnSetAttrType - alias: Optional[str] + alias: str | None class Attribute(Generic[_T]): name: str - default: Optional[_T] - validator: Optional[_ValidatorType[_T]] + default: _T | None + validator: _ValidatorType[_T] | None repr: _ReprArgType cmp: _EqOrderType eq: _EqOrderType order: _EqOrderType - hash: Optional[bool] + hash: bool | None init: bool - converter: Optional[_ConverterType] - metadata: Dict[Any, Any] - type: Optional[Type[_T]] + converter: _ConverterType | None + metadata: dict[Any, Any] + type: type[_T] | None kw_only: bool on_setattr: _OnSetAttrType - alias: Optional[str] + alias: str | None def evolve(self, **changes: Any) -> Attribute[Any]: ... @classmethod def from_counting_attr(cls, name: str, ca: _CountingAttr[_T], type: type[Any] | None = None) -> Attribute[_T]: ... +# NOTE: We had several choices for the annotation to use for type arg: +# 1) Type[_T] +# - Pros: Handles simple cases correctly +# - Cons: Might produce less informative errors in the case of conflicting +# TypeVars e.g. `attr.ib(default='bad', type=int)` +# 2) Callable[..., _T] +# - Pros: Better error messages than #1 for conflicting TypeVars +# - Cons: Terrible error messages for validator checks. +# e.g. attr.ib(type=int, validator=validate_str) +# -> error: Cannot infer function type argument +# 3) type (and do all of the work in the mypy plugin) +# - Pros: Simple here, and we could customize the plugin with our own errors. +# - Cons: Would need to write mypy plugin code to handle all the cases. +# We chose option #1. + +# `attr` lies about its return type to make the following possible: +# attr() -> Any +# attr(8) -> int +# attr(validator=) -> Whatever the callable expects. +# This makes this type of assignments possible: +# x: int = attr(8) +# +# This form catches explicit None or no default but with no other arguments +# returns Any. @overload def attrib( default: None = ..., @@ -138,6 +151,9 @@ def attrib( on_setattr: Optional[_OnSetAttrArgType] = ..., alias: Optional[str] = ..., ) -> Any: ... + +# This form catches an explicit None or no default and infers the type from the +# other arguments. @overload def attrib( default: None = ..., @@ -156,6 +172,8 @@ def attrib( on_setattr: Optional[_OnSetAttrArgType] = ..., alias: Optional[str] = ..., ) -> _T: ... + +# This form catches an explicit default argument. @overload def attrib( default: _T, @@ -174,6 +192,8 @@ def attrib( on_setattr: Optional[_OnSetAttrArgType] = ..., alias: Optional[str] = ..., ) -> _T: ... + +# This form covers type=non-Type: e.g. forward references (str), Any @overload def attrib( default: Optional[_T] = ..., @@ -210,6 +230,9 @@ def field( alias: Optional[str] = ..., type: Optional[type] = ..., ) -> Any: ... + +# This form catches an explicit None or no default and infers the type from the +# other arguments. @overload def field( *, @@ -228,6 +251,8 @@ def field( alias: Optional[str] = ..., type: Optional[type] = ..., ) -> _T: ... + +# This form catches an explicit default argument. @overload def field( *, @@ -246,6 +271,8 @@ def field( alias: Optional[str] = ..., type: Optional[type] = ..., ) -> _T: ... + +# This form covers type=non-Type: e.g. forward references (str), Any @overload def field( *, @@ -265,7 +292,7 @@ def field( type: Optional[type] = ..., ) -> Any: ... @overload -@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field)) +@dataclass_transform(order_default=True, field_specifiers=(attrib, field)) def attrs( maybe_cls: _C, these: Optional[Dict[str, Any]] = ..., @@ -293,7 +320,7 @@ def attrs( unsafe_hash: Optional[bool] = ..., ) -> _C: ... @overload -@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field)) +@dataclass_transform(order_default=True, field_specifiers=(attrib, field)) def attrs( maybe_cls: None = ..., these: Optional[Dict[str, Any]] = ..., @@ -321,7 +348,7 @@ def attrs( unsafe_hash: Optional[bool] = ..., ) -> Callable[[_C], _C]: ... @overload -@__dataclass_transform__(field_descriptors=(attrib, field)) +@dataclass_transform(field_specifiers=(attrib, field)) def define( maybe_cls: _C, *, @@ -347,7 +374,7 @@ def define( match_args: bool = ..., ) -> _C: ... @overload -@__dataclass_transform__(field_descriptors=(attrib, field)) +@dataclass_transform(field_specifiers=(attrib, field)) def define( maybe_cls: None = ..., *, @@ -373,10 +400,10 @@ def define( match_args: bool = ..., ) -> Callable[[_C], _C]: ... -mutable = ... +mutable = define @overload -@__dataclass_transform__(frozen_default=True, field_descriptors=(attrib, field)) +@dataclass_transform(frozen_default=True, field_specifiers=(attrib, field)) def frozen( maybe_cls: _C, *, @@ -402,7 +429,7 @@ def frozen( match_args: bool = ..., ) -> _C: ... @overload -@__dataclass_transform__(frozen_default=True, field_descriptors=(attrib, field)) +@dataclass_transform(frozen_default=True, field_specifiers=(attrib, field)) def frozen( maybe_cls: None = ..., *, @@ -437,6 +464,10 @@ def resolve_types( attribs: Optional[List[Attribute[Any]]] = ..., include_extras: bool = ..., ) -> _A: ... + +# TODO: add support for returning a proper attrs class from the mypy plugin +# we use Any instead of _CountingAttr so that e.g. `make_class('Foo', +# [attr.ib()])` is valid def make_class( name: str, attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]], @@ -460,6 +491,15 @@ def make_class( on_setattr: Optional[_OnSetAttrArgType] = ..., field_transformer: Optional[_FieldTransformer] = ..., ) -> type: ... + +# _funcs -- + +# TODO: add support for returning TypedDict from the mypy plugin +# FIXME: asdict/astuple do not honor their factory args. Waiting on one of +# these: +# https://github.com/python/mypy/issues/4236 +# https://github.com/python/typing/issues/253 +# XXX: remember to fix attrs.asdict/astuple too! def asdict( inst: AttrsInstance, recurse: bool = ..., @@ -469,6 +509,8 @@ def asdict( value_serializer: Optional[Callable[[type, Attribute[Any], Any], Any]] = ..., tuple_keys: Optional[bool] = ..., ) -> Dict[str, Any]: ... + +# TODO: add support for returning NamedTuple from the mypy plugin def astuple( inst: AttrsInstance, recurse: bool = ..., @@ -479,12 +521,17 @@ def astuple( def has(cls: type) -> TypeGuard[Type[AttrsInstance]]: ... def assoc(inst: _T, **changes: Any) -> _T: ... def evolve(inst: _T, **changes: Any) -> _T: ... + +# _config -- + def set_run_validators(run: bool) -> None: ... def get_run_validators() -> bool: ... -attributes = ... -attr = ... -dataclass = ... +# aliases -- + +s = attributes = attrs +ib = attr = attrib +dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;) def _make_init( cls: type[AttrsInstance], @@ -494,12 +541,11 @@ def _make_init( frozen: bool, slots: bool, cache_hash: bool, - base_attr_map: dict[str, Any], + base_attr_map: dict[Any, Any], is_exc: bool, cls_on_setattr: Any, attrs_init: bool, ) -> Callable[_P, Any]: ... -def _make_method(name: str, script: str, filename: str, globs: dict[str, Any]) -> Callable[..., Any]: ... def _make_repr(attrs: tuple[Attribute[Any]], ns: str | None, cls: AttrsInstance) -> ReprProtocol: ... def _transform_attrs( cls: type[AttrsInstance], @@ -507,5 +553,5 @@ def _transform_attrs( auto_attribs: bool, kw_only: bool, collect_by_mro: bool, - field_transformer: Optional[_FieldTransformer], -) -> tuple[tuple[attr.Attribute[Any], ...], tuple[attr.Attribute[Any], ...], dict[attr.Attribute[Any], type[Any]]]: ... + field_transformer: _FieldTransformer | None, +) -> tuple[tuple[Attribute[Any], ...], tuple[Attribute[Any], ...], dict[Attribute[Any], type[Any]]]: ... diff --git a/typings/attr/_typing_compat.pyi b/typings/attr/_typing_compat.pyi index 7b0615e1..b6c68a0c 100644 --- a/typings/attr/_typing_compat.pyi +++ b/typings/attr/_typing_compat.pyi @@ -2,12 +2,15 @@ from typing import Any from typing import ClassVar from typing import Protocol -MYPY = False +# MYPY is a special constant in mypy which works the same way as `TYPE_CHECKING`. +MYPY: bool = False if MYPY: + # A protocol to be able to statically accept an attrs class. class AttrsInstance_(Protocol): __attrs_attrs__: ClassVar[Any] - ... else: + # For type checkers without plug-in support use an empty protocol that + # will (hopefully) be combined into a union. class AttrsInstance_(Protocol): ... diff --git a/typings/click_option_group/_core.pyi b/typings/click_option_group/_core.pyi index 58a659d0..b423f216 100644 --- a/typings/click_option_group/_core.pyi +++ b/typings/click_option_group/_core.pyi @@ -1,5 +1,5 @@ -from collections.abc import Callable from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Mapping @@ -7,104 +7,49 @@ from typing import Optional from typing import Sequence from typing import Set from typing import Tuple +from typing import TypeAlias +from typing import TypeVar from typing import Union import click -FC = Union[Callable[..., Any], click.Command] +_R = TypeVar("_R") +_T = TypeVar("_T") +AnyCallable = Callable[..., Any] +Decorator: TypeAlias = Callable[[_T], _T] +_FC = TypeVar("_FC", bound=Union[AnyCallable, click.Command]) class GroupedOption(click.Option): - """Represents grouped (related) optional values - - The class should be used only with `OptionGroup` class for creating grouped options. - - :param param_decls: option declaration tuple - :param group: `OptionGroup` instance (the group for this option) - :param attrs: additional option attributes - """ - def __init__(self, param_decls: Optional[Sequence[str]] = ..., *, group: OptionGroup, **attrs: Any) -> None: ... @property - def group(self) -> OptionGroup: - """Returns the reference to the group for this option - - :return: `OptionGroup` the group instance for this option - """ - ... + def group(self) -> OptionGroup: ... def handle_parse_result( self, ctx: click.Context, opts: Mapping[str, Any], args: List[str] ) -> Tuple[Any, List[str]]: ... def get_help_record(self, ctx: click.Context) -> Optional[Tuple[str, str]]: ... class _GroupTitleFakeOption(click.Option): - """The helper `Option` class to display option group title in help""" - def __init__(self, param_decls: Optional[Sequence[str]] = ..., *, group: OptionGroup, **attrs: Any) -> None: ... def get_help_record(self, ctx: click.Context) -> Optional[Tuple[str, str]]: ... class OptionGroup: - """Option group manages grouped (related) options - - The class is used for creating the groups of options. The class can de used as based class to implement - specific behavior for grouped options. - - :param name: the group name. If it is not set the default group name will be used - :param help: the group help text or None - """ - def __init__(self, name: Optional[str] = ..., *, hidden: bool = ..., help: Optional[str] = ...) -> None: ... @property - def name(self) -> str: - """Returns the group name or empty string if it was not set - - :return: group name - """ - ... + def name(self) -> str: ... @property - def help(self) -> str: - """Returns the group help or empty string if it was not set - - :return: group help - """ - ... + def help(self) -> str: ... @property - def name_extra(self) -> List[str]: - """Returns extra name attributes for the group""" - ... + def name_extra(self) -> List[str]: ... @property - def forbidden_option_attrs(self) -> List[str]: - """Returns the list of forbidden option attributes for the group""" - ... - def get_help_record(self, ctx: click.Context) -> Optional[Tuple[str, str]]: - """Returns the help record for the group - - :param ctx: Click Context object - :return: the tuple of two fileds: `(name, help)` - """ - ... - def option(self, *param_decls: str, **attrs: Any) -> Callable: - """Decorator attaches an grouped option to the command - - The decorator is used for adding options to the group and to the Click-command - """ - ... - def get_options(self, ctx: click.Context) -> Dict[str, GroupedOption]: - """Returns the dictionary with group options""" - ... - def get_option_names(self, ctx: click.Context) -> List[str]: - """Returns the list with option names ordered by addition in the group""" - ... + def forbidden_option_attrs(self) -> List[str]: ... + def get_help_record(self, ctx: click.Context) -> Optional[Tuple[str, str]]: ... + def option(self, *param_decls: Any, **attrs: Any) -> Decorator[_FC]: ... + def get_options(self, ctx: click.Context) -> Dict[str, GroupedOption]: ... + def get_option_names(self, ctx: click.Context) -> List[str]: ... def get_error_hint(self, ctx: click.Context, option_names: Optional[Set[str]] = ...) -> str: ... - def handle_parse_result(self, option: GroupedOption, ctx: click.Context, opts: Mapping[str, Any]) -> None: - """The method should be used for adding specific behavior and relation for options in the group""" - ... + def handle_parse_result(self, option: GroupedOption, ctx: click.Context, opts: Mapping[str, Any]) -> None: ... class RequiredAnyOptionGroup(OptionGroup): - """Option group with required any options of this group - - `RequiredAnyOptionGroup` defines the behavior: At least one option from the group must be set. - """ - @property def forbidden_option_attrs(self) -> List[str]: ... @property @@ -112,11 +57,6 @@ class RequiredAnyOptionGroup(OptionGroup): def handle_parse_result(self, option: GroupedOption, ctx: click.Context, opts: Mapping[str, Any]) -> None: ... class RequiredAllOptionGroup(OptionGroup): - """Option group with required all options of this group - - `RequiredAllOptionGroup` defines the behavior: All options from the group must be set. - """ - @property def forbidden_option_attrs(self) -> List[str]: ... @property @@ -124,12 +64,6 @@ class RequiredAllOptionGroup(OptionGroup): def handle_parse_result(self, option: GroupedOption, ctx: click.Context, opts: Mapping[str, Any]) -> None: ... class MutuallyExclusiveOptionGroup(OptionGroup): - """Option group with mutually exclusive behavior for grouped options - - `MutuallyExclusiveOptionGroup` defines the behavior: - - Only one or none option from the group must be set - """ - @property def forbidden_option_attrs(self) -> List[str]: ... @property @@ -137,23 +71,11 @@ class MutuallyExclusiveOptionGroup(OptionGroup): def handle_parse_result(self, option: GroupedOption, ctx: click.Context, opts: Mapping[str, Any]) -> None: ... class RequiredMutuallyExclusiveOptionGroup(MutuallyExclusiveOptionGroup): - """Option group with required and mutually exclusive behavior for grouped options - - `RequiredMutuallyExclusiveOptionGroup` defines the behavior: - - Only one required option from the group must be set - """ - @property def name_extra(self) -> List[str]: ... def handle_parse_result(self, option: GroupedOption, ctx: click.Context, opts: Mapping[str, Any]) -> None: ... class AllOptionGroup(OptionGroup): - """Option group with required all/none options of this group - - `AllOptionGroup` defines the behavior: - - All options from the group must be set or None must be set - """ - @property def forbidden_option_attrs(self) -> List[str]: ... @property diff --git a/typings/click_option_group/_decorators.pyi b/typings/click_option_group/_decorators.pyi index 86ea6f19..41270532 100644 --- a/typings/click_option_group/_decorators.pyi +++ b/typings/click_option_group/_decorators.pyi @@ -1,112 +1,65 @@ from typing import Any from typing import Callable from typing import Dict -from typing import Generic from typing import NamedTuple from typing import Optional -from typing import ParamSpec -from typing import Protocol from typing import Tuple from typing import Type from typing import TypeVar +from typing import Union +from typing import overload import click -from ._core import FC +from ._core import _FC +from ._core import AnyCallable +from ._core import Decorator from ._core import OptionGroup -P = ParamSpec("P") -O_co = TypeVar("O_co", covariant=True) - -F = Callable[P, O_co] - class OptionStackItem(NamedTuple): param_decls: Tuple[str, ...] attrs: Dict[str, Any] param_count: int - ... - -class ClickFunctionWrapper(Protocol[P, O_co]): - __name__: str - __click_params__: list[click.Option] - - def __call__(self, *args: P.args, **kwargs: P.kwargs) -> O_co: ... class _NotAttachedOption(click.Option): - """The helper class to catch grouped options which were not attached to the group - - Raises TypeError if not attached options exist. - """ - def __init__(self, param_decls: Any = ..., *, all_not_attached_options: Any, **attrs: Any) -> None: ... def handle_parse_result(self, ctx: click.Context, opts: Any, args: tuple[Any]) -> Any: ... -class _OptGroup(Generic[O_co]): - """A helper class to manage creating groups and group options via decorators - - The class provides two decorator-methods: `group`/`__call__` and `option`. - These decorators should be used for adding grouped options. The class have - single global instance `optgroup` that should be used in most cases. - - The example of usage:: - - ... - @optgroup('Group 1', help='option group 1') - @optgroup.option('--foo') - @optgroup.option('--bar') - @optgroup.group('Group 2', help='option group 2') - @optgroup.option('--spam') - ... - """ +_GrpType = TypeVar("_GrpType", bound=OptionGroup) +class _OptGroup: def __init__(self) -> None: ... def __call__( self, name: Optional[str] = ..., *, - help: Optional[str] = ..., - cls: Optional[Type[OptionGroup]] = ..., + help: Optional[str] = None, + cls: Optional[Type[_GrpType]] = None, **attrs: Any, - ) -> FC: - """Creates a new group and collects its options - - Creates the option group and registers all grouped options - which were added by `option` decorator. - - :param name: Group name or None for default name - :param help: Group help or None for empty help - :param cls: Option group class that should be inherited from `OptionGroup` class - :param attrs: Additional parameters of option group class - """ - ... + ) -> Union[click.Command, Callable[[AnyCallable], click.Command]]: ... + @overload + def group( + self, + name: Optional[str], + cls: type[_GrpType], + **attrs: Any, + ) -> Callable[[AnyCallable], click.Command]: ... + @overload + def group( + self, + name: str = ..., + cls: None = None, + **attrs: Any, + ) -> Callable[[AnyCallable], click.Command]: ... + @overload def group( self, name: Optional[str] = ..., *, help: Optional[str] = ..., - cls: Optional[Type[OptionGroup]] = ..., + cls: Optional[Type[_GrpType]] = None, **attrs: Any, - ) -> FC: - """The decorator creates a new group and collects its options + ) -> Union[click.Command, Callable[[AnyCallable], click.Command]]: ... + def option(self, *param_decls: Any, **attrs: Any) -> Decorator[_FC]: ... - Creates the option group and registers all grouped options - which were added by `option` decorator. - - :param name: Group name or None for default name - :param help: Group help or None for empty help - :param cls: Option group class that should be inherited from `OptionGroup` class - :param attrs: Additional parameters of option group class - """ - ... - def option(self, *param_decls: Any, **attrs: Any) -> FC: - """The decorator adds a new option to the group - - The decorator is lazy. It adds option decls and attrs. - All options will be registered by `group` decorator. - - :param param_decls: option declaration tuple - :param attrs: additional option attributes and parameters - """ - ... - -optgroup: _OptGroup[Any] = ... +optgroup: _OptGroup = ... diff --git a/typings/click_option_group/_helpers.pyi b/typings/click_option_group/_helpers.pyi deleted file mode 100644 index 7c60ec75..00000000 --- a/typings/click_option_group/_helpers.pyi +++ /dev/null @@ -1,25 +0,0 @@ -""" -This type stub file was generated by pyright. -""" - -import collections.abc as abc -from typing import List -from typing import Tuple - -import click - -FAKE_OPT_NAME_LEN = ... - -def get_callback_and_params(func) -> Tuple[abc.Callable, List[click.Option]]: - """Returns callback function and its parameters list - - :param func: decorated function or click Command - :return: (callback, params) - """ - ... - -def get_fake_option_name(name_len: int = ..., prefix: str = ...): ... -def raise_mixing_decorators_error(wrong_option: click.Option, callback: abc.Callable): ... -def resolve_wrappers(f): - """Get the underlying function behind any level of function wrappers.""" - ... diff --git a/typings/click_option_group/_version.pyi b/typings/click_option_group/_version.pyi index 8454f4b7..5d59d877 100644 --- a/typings/click_option_group/_version.pyi +++ b/typings/click_option_group/_version.pyi @@ -1,5 +1,3 @@ -""" -This type stub file was generated by pyright. -""" +"""This type stub file was generated by pyright.""" __version__ = ... diff --git a/typings/deepmerge/strategy/dict.pyi b/typings/deepmerge/strategy/dict.pyi index e1faa7f8..50087412 100644 --- a/typings/deepmerge/strategy/dict.pyi +++ b/typings/deepmerge/strategy/dict.pyi @@ -1,5 +1,5 @@ -from ..merger import Merger from .core import StrategyList +from ..merger import Merger class DictStrategies(StrategyList): @staticmethod diff --git a/typings/deepmerge/strategy/list.pyi b/typings/deepmerge/strategy/list.pyi index dc338689..c3a1f0ea 100644 --- a/typings/deepmerge/strategy/list.pyi +++ b/typings/deepmerge/strategy/list.pyi @@ -1,5 +1,5 @@ -from ..merger import Merger from .core import StrategyList +from ..merger import Merger class ListStrategies(StrategyList): NAME: str = ... diff --git a/typings/deepmerge/strategy/set.pyi b/typings/deepmerge/strategy/set.pyi index 23cb3e64..bda98a37 100644 --- a/typings/deepmerge/strategy/set.pyi +++ b/typings/deepmerge/strategy/set.pyi @@ -1,7 +1,7 @@ from typing import Any -from ..merger import Merger from .core import StrategyList +from ..merger import Merger class SetStrategies(StrategyList): NAME = ... diff --git a/typings/docker/__init__.pyi b/typings/docker/__init__.pyi new file mode 100644 index 00000000..308a41fa --- /dev/null +++ b/typings/docker/__init__.pyi @@ -0,0 +1,2 @@ +from .client import DockerClient as DockerClient +from .client import from_env as from_env diff --git a/typings/docker/api/__init__.pyi b/typings/docker/api/__init__.pyi new file mode 100644 index 00000000..0b38db71 --- /dev/null +++ b/typings/docker/api/__init__.pyi @@ -0,0 +1 @@ +from .client import APIClient as APIClient diff --git a/typings/docker/api/build.pyi b/typings/docker/api/build.pyi new file mode 100644 index 00000000..2b6fb5ec --- /dev/null +++ b/typings/docker/api/build.pyi @@ -0,0 +1,150 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +log = ... + +class BuildApiMixin: + def build( + self, + path=..., + tag=..., + quiet=..., + fileobj=..., + nocache=..., + rm=..., + timeout=..., + custom_context=..., + encoding=..., + pull=..., + forcerm=..., + dockerfile=..., + container_limits=..., + decode=..., + buildargs=..., + gzip=..., + shmsize=..., + labels=..., + cache_from=..., + target=..., + network_mode=..., + squash=..., + extra_hosts=..., + platform=..., + isolation=..., + use_config_proxy=..., + ): + """Similar to the ``docker build`` command. Either ``path`` or ``fileobj`` + needs to be set. ``path`` can be a local path (to a directory + containing a Dockerfile) or a remote URL. ``fileobj`` must be a + readable file-like object to a Dockerfile. + + If you have a tar file for the Docker build context (including a + Dockerfile) already, pass a readable file-like object to ``fileobj`` + and also pass ``custom_context=True``. If the stream is compressed + also, set ``encoding`` to the correct value (e.g ``gzip``). + + Example: + >>> from io import BytesIO + >>> from docker import APIClient + >>> dockerfile = ''' + ... # Shared Volume + ... FROM busybox:buildroot-2014.02 + ... VOLUME /data + ... CMD ["/bin/sh"] + ... ''' + >>> f = BytesIO(dockerfile.encode('utf-8')) + >>> cli = APIClient(base_url='tcp://127.0.0.1:2375') + >>> response = [line for line in cli.build( + ... fileobj=f, rm=True, tag='yourname/volume' + ... )] + >>> response + ['{"stream":" ---\\u003e a9eb17255234\\n"}', + '{"stream":"Step 1 : VOLUME /data\\n"}', + '{"stream":" ---\\u003e Running in abdc1e6896c6\\n"}', + '{"stream":" ---\\u003e 713bca62012e\\n"}', + '{"stream":"Removing intermediate container abdc1e6896c6\\n"}', + '{"stream":"Step 2 : CMD [\\"/bin/sh\\"]\\n"}', + '{"stream":" ---\\u003e Running in dba30f2a1a7e\\n"}', + '{"stream":" ---\\u003e 032b8b2855fc\\n"}', + '{"stream":"Removing intermediate container dba30f2a1a7e\\n"}', + '{"stream":"Successfully built 032b8b2855fc\\n"}'] + + Args: + path (str): Path to the directory containing the Dockerfile + fileobj: A file object to use as the Dockerfile. (Or a file-like + object) + tag (str): A tag to add to the final image + quiet (bool): Whether to return the status + nocache (bool): Don't use the cache when set to ``True`` + rm (bool): Remove intermediate containers. The ``docker build`` + command now defaults to ``--rm=true``, but we have kept the old + default of `False` to preserve backward compatibility + timeout (int): HTTP timeout + custom_context (bool): Optional if using ``fileobj`` + encoding (str): The encoding for a stream. Set to ``gzip`` for + compressing + pull (bool): Downloads any updates to the FROM image in Dockerfiles + forcerm (bool): Always remove intermediate containers, even after + unsuccessful builds + dockerfile (str): path within the build context to the Dockerfile + gzip (bool): If set to ``True``, gzip compression/encoding is used + buildargs (dict): A dictionary of build arguments + container_limits (dict): A dictionary of limits applied to each + container created by the build process. Valid keys: + + - memory (int): set memory limit for build + - memswap (int): Total memory (memory + swap), -1 to disable + swap + - cpushares (int): CPU shares (relative weight) + - cpusetcpus (str): CPUs in which to allow execution, e.g., + ``"0-3"``, ``"0,1"`` + decode (bool): If set to ``True``, the returned stream will be + decoded into dicts on the fly. Default ``False`` + shmsize (int): Size of `/dev/shm` in bytes. The size must be + greater than 0. If omitted the system uses 64MB + labels (dict): A dictionary of labels to set on the image + cache_from (:py:class:`list`): A list of images used for build + cache resolution + target (str): Name of the build-stage to build in a multi-stage + Dockerfile + network_mode (str): networking mode for the run commands during + build + squash (bool): Squash the resulting images layers into a + single layer. + extra_hosts (dict): Extra hosts to add to /etc/hosts in building + containers, as a mapping of hostname to IP address. + platform (str): Platform in the format ``os[/arch[/variant]]`` + isolation (str): Isolation technology used during build. + Default: `None`. + use_config_proxy (bool): If ``True``, and if the docker client + configuration file (``~/.docker/config.json`` by default) + contains a proxy configuration, the corresponding environment + variables will be set in the container being built. + + Returns: + A generator for the build output. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + ``TypeError`` + If neither ``path`` nor ``fileobj`` is specified. + """ + ... + @utils.minimum_version("1.31") + def prune_builds(self): + """Delete the builder cache. + + Returns: + (dict): A dictionary containing information about the operation's + result. The ``SpaceReclaimed`` key indicates the amount of + bytes of disk space reclaimed. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + +def process_dockerfile(dockerfile, path): ... diff --git a/typings/docker/api/client.pyi b/typings/docker/api/client.pyi new file mode 100644 index 00000000..1a250599 --- /dev/null +++ b/typings/docker/api/client.pyi @@ -0,0 +1,96 @@ +"""This type stub file was generated by pyright.""" + +import requests + +from .build import BuildApiMixin +from .config import ConfigApiMixin +from .container import ContainerApiMixin +from .daemon import DaemonApiMixin +from .exec_api import ExecApiMixin +from .image import ImageApiMixin +from .network import NetworkApiMixin +from .plugin import PluginApiMixin +from .secret import SecretApiMixin +from .service import ServiceApiMixin +from .swarm import SwarmApiMixin +from .volume import VolumeApiMixin + +class APIClient( + requests.Session, + BuildApiMixin, + ConfigApiMixin, + ContainerApiMixin, + DaemonApiMixin, + ExecApiMixin, + ImageApiMixin, + NetworkApiMixin, + PluginApiMixin, + SecretApiMixin, + ServiceApiMixin, + SwarmApiMixin, + VolumeApiMixin, +): + """A low-level client for the Docker Engine API. + + Example: + >>> import docker + >>> client = docker.APIClient(base_url='unix://var/run/docker.sock') + >>> client.version() + {u'ApiVersion': u'1.33', + u'Arch': u'amd64', + u'BuildTime': u'2017-11-19T18:46:37.000000000+00:00', + u'GitCommit': u'f4ffd2511c', + u'GoVersion': u'go1.9.2', + u'KernelVersion': u'4.14.3-1-ARCH', + u'MinAPIVersion': u'1.12', + u'Os': u'linux', + u'Version': u'17.10.0-ce'} + + Args: + base_url (str): URL to the Docker server. For example, + ``unix:///var/run/docker.sock`` or ``tcp://127.0.0.1:1234``. + version (str): The version of the API to use. Set to ``auto`` to + automatically detect the server's version. Default: ``1.35`` + timeout (int): Default timeout for API calls, in seconds. + tls (bool or :py:class:`~docker.tls.TLSConfig`): Enable TLS. Pass + ``True`` to enable it with default options, or pass a + :py:class:`~docker.tls.TLSConfig` object to use custom + configuration. + user_agent (str): Set a custom user agent for requests to the server. + credstore_env (dict): Override environment variables when calling the + credential store process. + use_ssh_client (bool): If set to `True`, an ssh connection is made + via shelling out to the ssh client. Ensure the ssh client is + installed and configured on the host. + max_pool_size (int): The maximum number of connections + to save in the pool. + """ + + __attrs__ = ... + def __init__( + self, + base_url=..., + version=..., + timeout=..., + tls=..., + user_agent=..., + num_pools=..., + credstore_env=..., + use_ssh_client=..., + max_pool_size=..., + ) -> None: ... + def get_adapter(self, url): ... + @property + def api_version(self): ... + def reload_config(self, dockercfg_path=...): # -> None: + """Force a reload of the auth configuration. + + Args: + dockercfg_path (str): Use a custom path for the Docker config file + (default ``$HOME/.docker/config.json`` if present, + otherwise ``$HOME/.dockercfg``) + + Returns: + None + """ + ... diff --git a/typings/docker/api/config.pyi b/typings/docker/api/config.pyi new file mode 100644 index 00000000..5acf32ef --- /dev/null +++ b/typings/docker/api/config.pyi @@ -0,0 +1,61 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class ConfigApiMixin: + @utils.minimum_version("1.30") + def create_config(self, name, data, labels=..., templating=...): + """Create a config. + + Args: + name (string): Name of the config + data (bytes): Config data to be stored + labels (dict): A mapping of labels to assign to the config + templating (dict): dictionary containing the name of the + templating driver to be used expressed as + { name: } + + Returns (dict): ID of the newly created config + """ + ... + @utils.minimum_version("1.30") + @utils.check_resource("id") + def inspect_config(self, id): + """Retrieve config metadata. + + Args: + id (string): Full ID of the config to inspect + + Returns (dict): A dictionary of metadata + + Raises: + :py:class:`docker.errors.NotFound` + if no config with that ID exists + """ + ... + @utils.minimum_version("1.30") + @utils.check_resource("id") + def remove_config(self, id): # -> Literal[True]: + """Remove a config. + + Args: + id (string): Full ID of the config to remove + + Returns (boolean): True if successful + + Raises: + :py:class:`docker.errors.NotFound` + if no config with that ID exists + """ + ... + @utils.minimum_version("1.30") + def configs(self, filters=...): + """List configs. + + Args: + filters (dict): A map of filters to process on the configs + list. Available filters: ``names`` + + Returns (list): A list of configs + """ + ... diff --git a/typings/docker/api/container.pyi b/typings/docker/api/container.pyi new file mode 100644 index 00000000..0e8a0891 --- /dev/null +++ b/typings/docker/api/container.pyi @@ -0,0 +1,962 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class ContainerApiMixin: + @utils.check_resource("container") + def attach(self, container, stdout=..., stderr=..., stream=..., logs=..., demux=...): # -> CancellableStream: + """Attach to a container. + + The ``.logs()`` function is a wrapper around this method, which you can + use instead if you want to fetch/stream container output without first + retrieving the entire backlog. + + Args: + container (str): The container to attach to. + stdout (bool): Include stdout. + stderr (bool): Include stderr. + stream (bool): Return container output progressively as an iterator + of strings, rather than a single string. + logs (bool): Include the container's previous output. + demux (bool): Keep stdout and stderr separate. + + Returns: + By default, the container's output as a single string (two if + ``demux=True``: one for stdout and one for stderr). + + If ``stream=True``, an iterator of output strings. If + ``demux=True``, two iterators are returned: one for stdout and one + for stderr. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def attach_socket(self, container, params=..., ws=...): + """Like ``attach``, but returns the underlying socket-like object for the + HTTP request. + + Args: + container (str): The container to attach to. + params (dict): Dictionary of request parameters (e.g. ``stdout``, + ``stderr``, ``stream``). + For ``detachKeys``, ~/.docker/config.json is used by default. + ws (bool): Use websockets instead of raw HTTP. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def commit(self, container, repository=..., tag=..., message=..., author=..., changes=..., conf=...): + """Commit a container to an image. Similar to the ``docker commit`` + command. + + Args: + container (str): The image hash of the container + repository (str): The repository to push the image to + tag (str): The tag to push + message (str): A commit message + author (str): The name of the author + changes (str): Dockerfile instructions to apply while committing + conf (dict): The configuration for the container. See the + `Engine API documentation + `_ + for full details. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def containers( + self, quiet=..., all=..., trunc=..., latest=..., since=..., before=..., limit=..., size=..., filters=... + ): # -> list[dict[str, Unknown]]: + """List containers. Similar to the ``docker ps`` command. + + Args: + quiet (bool): Only display numeric Ids + all (bool): Show all containers. Only running containers are shown + by default + trunc (bool): Truncate output + latest (bool): Show only the latest created container, include + non-running ones. + since (str): Show only containers created since Id or Name, include + non-running ones + before (str): Show only container created before Id or Name, + include non-running ones + limit (int): Show `limit` last created containers, include + non-running ones + size (bool): Display sizes + filters (dict): Filters to be processed on the image list. + Available filters: + + - `exited` (int): Only containers with specified exit code + - `status` (str): One of ``restarting``, ``running``, + ``paused``, ``exited`` + - `label` (str|list): format either ``"key"``, ``"key=value"`` + or a list of such. + - `id` (str): The id of the container. + - `name` (str): The name of the container. + - `ancestor` (str): Filter by container ancestor. Format of + ``[:tag]``, ````, or + ````. + - `before` (str): Only containers created before a particular + container. Give the container name or id. + - `since` (str): Only containers created after a particular + container. Give container name or id. + + A comprehensive list can be found in the documentation for + `docker ps + `_. + + Returns: + A list of dicts, one per container + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def create_container( + self, + image, + command=..., + hostname=..., + user=..., + detach=..., + stdin_open=..., + tty=..., + ports=..., + environment=..., + volumes=..., + network_disabled=..., + name=..., + entrypoint=..., + working_dir=..., + domainname=..., + host_config=..., + mac_address=..., + labels=..., + stop_signal=..., + networking_config=..., + healthcheck=..., + stop_timeout=..., + runtime=..., + use_config_proxy=..., + platform=..., + ): + """Creates a container. Parameters are similar to those for the ``docker + run`` command except it doesn't support the attach options (``-a``). + + The arguments that are passed directly to this function are + host-independent configuration options. Host-specific configuration + is passed with the `host_config` argument. You'll normally want to + use this method in combination with the :py:meth:`create_host_config` + method to generate ``host_config``. + + **Port bindings** + + Port binding is done in two parts: first, provide a list of ports to + open inside the container with the ``ports`` parameter, then declare + bindings with the ``host_config`` parameter. For example: + + .. code-block:: python + + container_id = client.api.create_container( + 'busybox', 'ls', ports=[1111, 2222], + host_config=client.api.create_host_config(port_bindings={ + 1111: 4567, + 2222: None + }) + ) + + + You can limit the host address on which the port will be exposed like + such: + + .. code-block:: python + + client.api.create_host_config( + port_bindings={1111: ('127.0.0.1', 4567)} + ) + + Or without host port assignment: + + .. code-block:: python + + client.api.create_host_config(port_bindings={1111: ('127.0.0.1',)}) + + If you wish to use UDP instead of TCP (default), you need to declare + ports as such in both the config and host config: + + .. code-block:: python + + container_id = client.api.create_container( + 'busybox', 'ls', ports=[(1111, 'udp'), 2222], + host_config=client.api.create_host_config(port_bindings={ + '1111/udp': 4567, 2222: None + }) + ) + + To bind multiple host ports to a single container port, use the + following syntax: + + .. code-block:: python + + client.api.create_host_config(port_bindings={ + 1111: [1234, 4567] + }) + + You can also bind multiple IPs to a single container port: + + .. code-block:: python + + client.api.create_host_config(port_bindings={ + 1111: [ + ('192.168.0.100', 1234), + ('192.168.0.101', 1234) + ] + }) + + **Using volumes** + + Volume declaration is done in two parts. Provide a list of + paths to use as mountpoints inside the container with the + ``volumes`` parameter, and declare mappings from paths on the host + in the ``host_config`` section. + + .. code-block:: python + + container_id = client.api.create_container( + 'busybox', 'ls', volumes=['/mnt/vol1', '/mnt/vol2'], + host_config=client.api.create_host_config(binds={ + '/home/user1/': { + 'bind': '/mnt/vol2', + 'mode': 'rw', + }, + '/var/www': { + 'bind': '/mnt/vol1', + 'mode': 'ro', + } + }) + ) + + You can alternatively specify binds as a list. This code is equivalent + to the example above: + + .. code-block:: python + + container_id = client.api.create_container( + 'busybox', 'ls', volumes=['/mnt/vol1', '/mnt/vol2'], + host_config=client.api.create_host_config(binds=[ + '/home/user1/:/mnt/vol2', + '/var/www:/mnt/vol1:ro', + ]) + ) + + **Networking** + + You can specify networks to connect the container to by using the + ``networking_config`` parameter. At the time of creation, you can + only connect a container to a single networking, but you + can create more connections by using + :py:meth:`~connect_container_to_network`. + + For example: + + .. code-block:: python + + networking_config = client.api.create_networking_config({ + 'network1': client.api.create_endpoint_config( + ipv4_address='172.28.0.124', + aliases=['foo', 'bar'], + links=['container2'] + ) + }) + + ctnr = client.api.create_container( + img, command, networking_config=networking_config + ) + + Args: + image (str): The image to run + command (str or list): The command to be run in the container + hostname (str): Optional hostname for the container + user (str or int): Username or UID + detach (bool): Detached mode: run container in the background and + return container ID + stdin_open (bool): Keep STDIN open even if not attached + tty (bool): Allocate a pseudo-TTY + ports (list of ints): A list of port numbers + environment (dict or list): A dictionary or a list of strings in + the following format ``["PASSWORD=xxx"]`` or + ``{"PASSWORD": "xxx"}``. + volumes (str or list): List of paths inside the container to use + as volumes. + network_disabled (bool): Disable networking + name (str): A name for the container + entrypoint (str or list): An entrypoint + working_dir (str): Path to the working directory + domainname (str): The domain name to use for the container + host_config (dict): A dictionary created with + :py:meth:`create_host_config`. + mac_address (str): The Mac Address to assign the container + labels (dict or list): A dictionary of name-value labels (e.g. + ``{"label1": "value1", "label2": "value2"}``) or a list of + names of labels to set with empty values (e.g. + ``["label1", "label2"]``) + stop_signal (str): The stop signal to use to stop the container + (e.g. ``SIGINT``). + stop_timeout (int): Timeout to stop the container, in seconds. + Default: 10 + networking_config (dict): A networking configuration generated + by :py:meth:`create_networking_config`. + runtime (str): Runtime to use with this container. + healthcheck (dict): Specify a test to perform to check that the + container is healthy. + use_config_proxy (bool): If ``True``, and if the docker client + configuration file (``~/.docker/config.json`` by default) + contains a proxy configuration, the corresponding environment + variables will be set in the container being created. + platform (str): Platform in the format ``os[/arch[/variant]]``. + + Returns: + A dictionary with an image 'Id' key and a 'Warnings' key. + + Raises: + :py:class:`docker.errors.ImageNotFound` + If the specified image does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def create_container_config(self, *args, **kwargs): ... + def create_container_from_config(self, config, name=..., platform=...): ... + def create_host_config(self, *args, **kwargs): # -> HostConfig: + """Create a dictionary for the ``host_config`` argument to + :py:meth:`create_container`. + + Args: + auto_remove (bool): enable auto-removal of the container on daemon + side when the container's process exits. + binds (dict): Volumes to bind. See :py:meth:`create_container` + for more information. + blkio_weight_device: Block IO weight (relative device weight) in + the form of: ``[{"Path": "device_path", "Weight": weight}]``. + blkio_weight: Block IO weight (relative weight), accepts a weight + value between 10 and 1000. + cap_add (list of str): Add kernel capabilities. For example, + ``["SYS_ADMIN", "MKNOD"]``. + cap_drop (list of str): Drop kernel capabilities. + cpu_period (int): The length of a CPU period in microseconds. + cpu_quota (int): Microseconds of CPU time that the container can + get in a CPU period. + cpu_shares (int): CPU shares (relative weight). + cpuset_cpus (str): CPUs in which to allow execution (``0-3``, + ``0,1``). + cpuset_mems (str): Memory nodes (MEMs) in which to allow execution + (``0-3``, ``0,1``). Only effective on NUMA systems. + device_cgroup_rules (:py:class:`list`): A list of cgroup rules to + apply to the container. + device_read_bps: Limit read rate (bytes per second) from a device + in the form of: `[{"Path": "device_path", "Rate": rate}]` + device_read_iops: Limit read rate (IO per second) from a device. + device_write_bps: Limit write rate (bytes per second) from a + device. + device_write_iops: Limit write rate (IO per second) from a device. + devices (:py:class:`list`): Expose host devices to the container, + as a list of strings in the form + ``::``. + + For example, ``/dev/sda:/dev/xvda:rwm`` allows the container + to have read-write access to the host's ``/dev/sda`` via a + node named ``/dev/xvda`` inside the container. + device_requests (:py:class:`list`): Expose host resources such as + GPUs to the container, as a list of + :py:class:`docker.types.DeviceRequest` instances. + dns (:py:class:`list`): Set custom DNS servers. + dns_opt (:py:class:`list`): Additional options to be added to the + container's ``resolv.conf`` file + dns_search (:py:class:`list`): DNS search domains. + extra_hosts (dict): Additional hostnames to resolve inside the + container, as a mapping of hostname to IP address. + group_add (:py:class:`list`): List of additional group names and/or + IDs that the container process will run as. + init (bool): Run an init inside the container that forwards + signals and reaps processes + ipc_mode (str): Set the IPC mode for the container. + isolation (str): Isolation technology to use. Default: ``None``. + links (dict): Mapping of links using the + ``{'container': 'alias'}`` format. The alias is optional. + Containers declared in this dict will be linked to the new + container using the provided alias. Default: ``None``. + log_config (LogConfig): Logging configuration + lxc_conf (dict): LXC config. + mem_limit (float or str): Memory limit. Accepts float values + (which represent the memory limit of the created container in + bytes) or a string with a units identification char + (``100000b``, ``1000k``, ``128m``, ``1g``). If a string is + specified without a units character, bytes are assumed as an + mem_reservation (float or str): Memory soft limit. + mem_swappiness (int): Tune a container's memory swappiness + behavior. Accepts number between 0 and 100. + memswap_limit (str or int): Maximum amount of memory + swap a + container is allowed to consume. + mounts (:py:class:`list`): Specification for mounts to be added to + the container. More powerful alternative to ``binds``. Each + item in the list is expected to be a + :py:class:`docker.types.Mount` object. + network_mode (str): One of: + + - ``bridge`` Create a new network stack for the container on + the bridge network. + - ``none`` No networking for this container. + - ``container:`` Reuse another container's network + stack. + - ``host`` Use the host network stack. + This mode is incompatible with ``port_bindings``. + + oom_kill_disable (bool): Whether to disable OOM killer. + oom_score_adj (int): An integer value containing the score given + to the container in order to tune OOM killer preferences. + pid_mode (str): If set to ``host``, use the host PID namespace + inside the container. + pids_limit (int): Tune a container's pids limit. Set ``-1`` for + unlimited. + port_bindings (dict): See :py:meth:`create_container` + for more information. + Imcompatible with ``host`` in ``network_mode``. + privileged (bool): Give extended privileges to this container. + publish_all_ports (bool): Publish all ports to the host. + read_only (bool): Mount the container's root filesystem as read + only. + restart_policy (dict): Restart the container when it exits. + Configured as a dictionary with keys: + + - ``Name`` One of ``on-failure``, or ``always``. + - ``MaximumRetryCount`` Number of times to restart the + container on failure. + security_opt (:py:class:`list`): A list of string values to + customize labels for MLS systems, such as SELinux. + shm_size (str or int): Size of /dev/shm (e.g. ``1G``). + storage_opt (dict): Storage driver options per container as a + key-value mapping. + sysctls (dict): Kernel parameters to set in the container. + tmpfs (dict): Temporary filesystems to mount, as a dictionary + mapping a path inside the container to options for that path. + + For example: + + .. code-block:: python + + { + '/mnt/vol2': '', + '/mnt/vol1': 'size=3G,uid=1000' + } + + ulimits (:py:class:`list`): Ulimits to set inside the container, + as a list of :py:class:`docker.types.Ulimit` instances. + userns_mode (str): Sets the user namespace mode for the container + when user namespace remapping option is enabled. Supported + values are: ``host`` + uts_mode (str): Sets the UTS namespace mode for the container. + Supported values are: ``host`` + volumes_from (:py:class:`list`): List of container names or IDs to + get volumes from. + runtime (str): Runtime to use with this container. + + + Returns: + (dict) A dictionary which can be passed to the ``host_config`` + argument to :py:meth:`create_container`. + + Example: + >>> client.api.create_host_config( + ... privileged=True, + ... cap_drop=['MKNOD'], + ... volumes_from=['nostalgic_newton'], + ... ) + {'CapDrop': ['MKNOD'], 'LxcConf': None, 'Privileged': True, + 'VolumesFrom': ['nostalgic_newton'], 'PublishAllPorts': False} + """ + ... + def create_networking_config(self, *args, **kwargs): # -> NetworkingConfig: + """Create a networking config dictionary to be used as the + ``networking_config`` parameter in :py:meth:`create_container`. + + Args: + endpoints_config (dict): A dictionary mapping network names to + endpoint configurations generated by + :py:meth:`create_endpoint_config`. + + Returns: + (dict) A networking config. + + Example: + >>> client.api.create_network('network1') + >>> networking_config = client.api.create_networking_config({ + 'network1': client.api.create_endpoint_config() + }) + >>> container = client.api.create_container( + img, command, networking_config=networking_config + ) + + """ + ... + def create_endpoint_config(self, *args, **kwargs): # -> EndpointConfig: + """Create an endpoint config dictionary to be used with + :py:meth:`create_networking_config`. + + Args: + aliases (:py:class:`list`): A list of aliases for this endpoint. + Names in that list can be used within the network to reach the + container. Defaults to ``None``. + links (dict): Mapping of links for this endpoint using the + ``{'container': 'alias'}`` format. The alias is optional. + Containers declared in this dict will be linked to this + container using the provided alias. Defaults to ``None``. + ipv4_address (str): The IP address of this container on the + network, using the IPv4 protocol. Defaults to ``None``. + ipv6_address (str): The IP address of this container on the + network, using the IPv6 protocol. Defaults to ``None``. + link_local_ips (:py:class:`list`): A list of link-local (IPv4/IPv6) + addresses. + driver_opt (dict): A dictionary of options to provide to the + network driver. Defaults to ``None``. + + Returns: + (dict) An endpoint config. + + Example: + >>> endpoint_config = client.api.create_endpoint_config( + aliases=['web', 'app'], + links={'app_db': 'db', 'another': None}, + ipv4_address='132.65.0.123' + ) + + """ + ... + @utils.check_resource("container") + def diff(self, container): + """Inspect changes on a container's filesystem. + + Args: + container (str): The container to diff + + Returns: + (list) A list of dictionaries containing the attributes `Path` + and `Kind`. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def export(self, container, chunk_size=...): + """Export the contents of a filesystem as a tar archive. + + Args: + container (str): The container to export + chunk_size (int): The number of bytes returned by each iteration + of the generator. If ``None``, data will be streamed as it is + received. Default: 2 MB + + Returns: + (generator): The archived filesystem data stream + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def get_archive(self, container, path, chunk_size=..., encode_stream=...): # -> tuple[Unknown, Any | None]: + """Retrieve a file or folder from a container in the form of a tar + archive. + + Args: + container (str): The container where the file is located + path (str): Path to the file or folder to retrieve + chunk_size (int): The number of bytes returned by each iteration + of the generator. If ``None``, data will be streamed as it is + received. Default: 2 MB + encode_stream (bool): Determines if data should be encoded + (gzip-compressed) during transmission. Default: False + + Returns: + (tuple): First element is a raw tar data stream. Second element is + a dict containing ``stat`` information on the specified ``path``. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> c = docker.APIClient() + >>> f = open('./sh_bin.tar', 'wb') + >>> bits, stat = c.api.get_archive(container, '/bin/sh') + >>> print(stat) + {'name': 'sh', 'size': 1075464, 'mode': 493, + 'mtime': '2018-10-01T15:37:48-07:00', 'linkTarget': ''} + >>> for chunk in bits: + ... f.write(chunk) + >>> f.close() + """ + ... + @utils.check_resource("container") + def inspect_container(self, container): + """Identical to the `docker inspect` command, but only for containers. + + Args: + container (str): The container to inspect + + Returns: + (dict): Similar to the output of `docker inspect`, but as a + single dict + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def kill(self, container, signal=...): # -> None: + """Kill a container or send a signal to a container. + + Args: + container (str): The container to kill + signal (str or int): The signal to send. Defaults to ``SIGKILL`` + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def logs( + self, container, stdout=..., stderr=..., stream=..., timestamps=..., tail=..., since=..., follow=..., until=... + ): # -> CancellableStream: + """Get logs from a container. Similar to the ``docker logs`` command. + + The ``stream`` parameter makes the ``logs`` function return a blocking + generator you can iterate over to retrieve log output as it happens. + + Args: + container (str): The container to get logs from + stdout (bool): Get ``STDOUT``. Default ``True`` + stderr (bool): Get ``STDERR``. Default ``True`` + stream (bool): Stream the response. Default ``False`` + timestamps (bool): Show timestamps. Default ``False`` + tail (str or int): Output specified number of lines at the end of + logs. Either an integer of number of lines or the string + ``all``. Default ``all`` + since (datetime, int, or float): Show logs since a given datetime, + integer epoch (in seconds) or float (in fractional seconds) + follow (bool): Follow log output. Default ``False`` + until (datetime, int, or float): Show logs that occurred before + the given datetime, integer epoch (in seconds), or + float (in fractional seconds) + + Returns: + (generator or str) + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def pause(self, container): # -> None: + """Pauses all processes within a container. + + Args: + container (str): The container to pause + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def port(self, container, private_port): # -> None: + """Lookup the public-facing port that is NAT-ed to ``private_port``. + Identical to the ``docker port`` command. + + Args: + container (str): The container to look up + private_port (int): The private port to inspect + + Returns: + (list of dict): The mapping for the host ports + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + .. code-block:: bash + + $ docker run -d -p 80:80 ubuntu:14.04 /bin/sleep 30 + 7174d6347063a83f412fad6124c99cffd25ffe1a0807eb4b7f9cec76ac8cb43b + + .. code-block:: python + + >>> client.api.port('7174d6347063', 80) + [{'HostIp': '0.0.0.0', 'HostPort': '80'}] + """ + ... + @utils.check_resource("container") + def put_archive(self, container, path, data): + """Insert a file or folder in an existing container using a tar archive as + source. + + Args: + container (str): The container where the file(s) will be extracted + path (str): Path inside the container where the file(s) will be + extracted. Must exist. + data (bytes or stream): tar data to be extracted + + Returns: + (bool): True if the call succeeds. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.25") + def prune_containers(self, filters=...): + """Delete stopped containers. + + Args: + filters (dict): Filters to process on the prune list. + + Returns: + (dict): A dict containing a list of deleted container IDs and + the amount of disk space reclaimed in bytes. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def remove_container(self, container, v=..., link=..., force=...): # -> None: + """Remove a container. Similar to the ``docker rm`` command. + + Args: + container (str): The container to remove + v (bool): Remove the volumes associated with the container + link (bool): Remove the specified link and not the underlying + container + force (bool): Force the removal of a running container (uses + ``SIGKILL``) + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def rename(self, container, name): # -> None: + """Rename a container. Similar to the ``docker rename`` command. + + Args: + container (str): ID of the container to rename + name (str): New name for the container + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def resize(self, container, height, width): # -> None: + """Resize the tty session. + + Args: + container (str or dict): The container to resize + height (int): Height of tty session + width (int): Width of tty session + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def restart(self, container, timeout=...): # -> None: + """Restart a container. Similar to the ``docker restart`` command. + + Args: + container (str or dict): The container to restart. If a dict, the + ``Id`` key is used. + timeout (int): Number of seconds to try to stop for before killing + the container. Once killed it will then be restarted. Default + is 10 seconds. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def start(self, container, *args, **kwargs): # -> None: + """Start a container. Similar to the ``docker start`` command, but + doesn't support attach options. + + **Deprecation warning:** Passing configuration options in ``start`` is + no longer supported. Users are expected to provide host config options + in the ``host_config`` parameter of + :py:meth:`~ContainerApiMixin.create_container`. + + + Args: + container (str): The container to start + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + :py:class:`docker.errors.DeprecatedMethod` + If any argument besides ``container`` are provided. + + Example: + >>> container = client.api.create_container( + ... image='busybox:latest', + ... command='/bin/sleep 30') + >>> client.api.start(container=container.get('Id')) + """ + ... + @utils.check_resource("container") + def stats(self, container, decode=..., stream=..., one_shot=...): + """Stream statistics for a specific container. Similar to the + ``docker stats`` command. + + Args: + container (str): The container to stream statistics from + decode (bool): If set to true, stream will be decoded into dicts + on the fly. Only applicable if ``stream`` is True. + False by default. + stream (bool): If set to false, only the current stats will be + returned instead of a stream. True by default. + one_shot (bool): If set to true, Only get a single stat instead of + waiting for 2 cycles. Must be used with stream=false. False by + default. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + """ + ... + @utils.check_resource("container") + def stop(self, container, timeout=...): # -> None: + """Stops a container. Similar to the ``docker stop`` command. + + Args: + container (str): The container to stop + timeout (int): Timeout in seconds to wait for the container to + stop before sending a ``SIGKILL``. If None, then the + StopTimeout value of the container will be used. + Default: None + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def top(self, container, ps_args=...): + """Display the running processes of a container. + + Args: + container (str): The container to inspect + ps_args (str): An optional arguments passed to ps (e.g. ``aux``) + + Returns: + (str): The output of the top + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def unpause(self, container): # -> None: + """Unpause all processes within a container. + + Args: + container (str): The container to unpause + """ + ... + @utils.minimum_version("1.22") + @utils.check_resource("container") + def update_container( + self, + container, + blkio_weight=..., + cpu_period=..., + cpu_quota=..., + cpu_shares=..., + cpuset_cpus=..., + cpuset_mems=..., + mem_limit=..., + mem_reservation=..., + memswap_limit=..., + kernel_memory=..., + restart_policy=..., + ): + """Update resource configs of one or more containers. + + Args: + container (str): The container to inspect + blkio_weight (int): Block IO (relative weight), between 10 and 1000 + cpu_period (int): Limit CPU CFS (Completely Fair Scheduler) period + cpu_quota (int): Limit CPU CFS (Completely Fair Scheduler) quota + cpu_shares (int): CPU shares (relative weight) + cpuset_cpus (str): CPUs in which to allow execution + cpuset_mems (str): MEMs in which to allow execution + mem_limit (float or str): Memory limit + mem_reservation (float or str): Memory soft limit + memswap_limit (int or str): Total memory (memory + swap), -1 to + disable swap + kernel_memory (int or str): Kernel memory limit + restart_policy (dict): Restart policy dictionary + + Returns: + (dict): Dictionary containing a ``Warnings`` key. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("container") + def wait(self, container, timeout=..., condition=...): + """Block until a container stops, then return its exit code. Similar to + the ``docker wait`` command. + + Args: + container (str or dict): The container to wait on. If a dict, the + ``Id`` key is used. + timeout (int): Request timeout + condition (str): Wait until a container state reaches the given + condition, either ``not-running`` (default), ``next-exit``, + or ``removed`` + + Returns: + (dict): The API's response as a Python dictionary, including + the container's exit code under the ``StatusCode`` attribute. + + Raises: + :py:class:`requests.exceptions.ReadTimeout` + If the timeout is exceeded. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/api/daemon.pyi b/typings/docker/api/daemon.pyi new file mode 100644 index 00000000..047bc663 --- /dev/null +++ b/typings/docker/api/daemon.pyi @@ -0,0 +1,115 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class DaemonApiMixin: + @utils.minimum_version("1.25") + def df(self): + """Get data usage information. + + Returns: + (dict): A dictionary representing different resource categories + and their respective data usage. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def events(self, since=..., until=..., filters=..., decode=...): # -> CancellableStream: + """Get real-time events from the server. Similar to the ``docker events`` + command. + + Args: + since (UTC datetime or int): Get events from this point + until (UTC datetime or int): Get events until this point + filters (dict): Filter the events by event time, container or image + decode (bool): If set to true, stream will be decoded into dicts on + the fly. False by default. + + Returns: + A :py:class:`docker.types.daemon.CancellableStream` generator + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> for event in client.events(decode=True) + ... print(event) + {u'from': u'image/with:tag', + u'id': u'container-id', + u'status': u'start', + u'time': 1423339459} + ... + + or + + >>> events = client.events() + >>> for event in events: + ... print(event) + >>> # and cancel from another thread + >>> events.close() + """ + ... + def info(self): + """Display system-wide information. Identical to the ``docker info`` + command. + + Returns: + (dict): The info as a dict + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def login( + self, username, password=..., email=..., registry=..., reauth=..., dockercfg_path=... + ): # -> dict[str, str | Unknown] | Any: + """Authenticate with a registry. Similar to the ``docker login`` command. + + Args: + username (str): The registry username + password (str): The plaintext password + email (str): The email for the registry account + registry (str): URL to the registry. E.g. + ``https://index.docker.io/v1/`` + reauth (bool): Whether or not to refresh existing authentication on + the Docker server. + dockercfg_path (str): Use a custom path for the Docker config file + (default ``$HOME/.docker/config.json`` if present, + otherwise ``$HOME/.dockercfg``) + + Returns: + (dict): The response from the login request + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def ping(self): + """Checks the server is responsive. An exception will be raised if it + isn't responding. + + Returns: + (bool) The response from the server. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def version(self, api_version=...): + """Returns version information from the server. Similar to the ``docker + version`` command. + + Returns: + (dict): The server version information + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/api/exec_api.pyi b/typings/docker/api/exec_api.pyi new file mode 100644 index 00000000..1543eb61 --- /dev/null +++ b/typings/docker/api/exec_api.pyi @@ -0,0 +1,100 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class ExecApiMixin: + @utils.check_resource("container") + def exec_create( + self, + container, + cmd, + stdout=..., + stderr=..., + stdin=..., + tty=..., + privileged=..., + user=..., + environment=..., + workdir=..., + detach_keys=..., + ): + """Sets up an exec instance in a running container. + + Args: + container (str): Target container where exec instance will be + created + cmd (str or list): Command to be executed + stdout (bool): Attach to stdout. Default: ``True`` + stderr (bool): Attach to stderr. Default: ``True`` + stdin (bool): Attach to stdin. Default: ``False`` + tty (bool): Allocate a pseudo-TTY. Default: False + privileged (bool): Run as privileged. + user (str): User to execute command as. Default: root + environment (dict or list): A dictionary or a list of strings in + the following format ``["PASSWORD=xxx"]`` or + ``{"PASSWORD": "xxx"}``. + workdir (str): Path to working directory for this exec session + detach_keys (str): Override the key sequence for detaching + a container. Format is a single character `[a-Z]` + or `ctrl-` where `` is one of: + `a-z`, `@`, `^`, `[`, `,` or `_`. + ~/.docker/config.json is used by default. + + Returns: + (dict): A dictionary with an exec ``Id`` key. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def exec_inspect(self, exec_id): + """Return low-level information about an exec command. + + Args: + exec_id (str): ID of the exec instance + + Returns: + (dict): Dictionary of values returned by the endpoint. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def exec_resize(self, exec_id, height=..., width=...): # -> None: + """Resize the tty session used by the specified exec command. + + Args: + exec_id (str): ID of the exec instance + height (int): Height of tty session + width (int): Width of tty session + """ + ... + @utils.check_resource("exec_id") + def exec_start(self, exec_id, detach=..., tty=..., stream=..., socket=..., demux=...): # -> CancellableStream: + """Start a previously set up exec instance. + + Args: + exec_id (str): ID of the exec instance + detach (bool): If true, detach from the exec command. + Default: False + tty (bool): Allocate a pseudo-TTY. Default: False + stream (bool): Return response data progressively as an iterator + of strings, rather than a single string. + socket (bool): Return the connection socket to allow custom + read/write operations. Must be closed by the caller when done. + demux (bool): Return stdout and stderr separately + + Returns: + (generator or str or tuple): If ``stream=True``, a generator + yielding response chunks. If ``socket=True``, a socket object for + the connection. A string containing response data otherwise. If + ``demux=True``, a tuple with two elements of type byte: stdout and + stderr. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/api/image.pyi b/typings/docker/api/image.pyi new file mode 100644 index 00000000..fdb070f5 --- /dev/null +++ b/typings/docker/api/image.pyi @@ -0,0 +1,336 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +log = ... + +class ImageApiMixin: + @utils.check_resource("image") + def get_image(self, image, chunk_size=...): + """Get a tarball of an image. Similar to the ``docker save`` command. + + Args: + image (str): Image name to get + chunk_size (int): The number of bytes returned by each iteration + of the generator. If ``None``, data will be streamed as it is + received. Default: 2 MB + + Returns: + (generator): A stream of raw archive data. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> image = client.api.get_image("busybox:latest") + >>> f = open('/tmp/busybox-latest.tar', 'wb') + >>> for chunk in image: + >>> f.write(chunk) + >>> f.close() + """ + ... + @utils.check_resource("image") + def history(self, image): + """Show the history of an image. + + Args: + image (str): The image to show history for + + Returns: + (str): The history of the image + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def images(self, name=..., quiet=..., all=..., filters=...): # -> list[Unknown]: + """List images. Similar to the ``docker images`` command. + + Args: + name (str): Only show images belonging to the repository ``name`` + quiet (bool): Only return numeric IDs as a list. + all (bool): Show intermediate image layers. By default, these are + filtered out. + filters (dict): Filters to be processed on the image list. + Available filters: + - ``dangling`` (bool) + - `label` (str|list): format either ``"key"``, ``"key=value"`` + or a list of such. + + Returns: + (dict or list): A list if ``quiet=True``, otherwise a dict. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def import_image(self, src=..., repository=..., tag=..., image=..., changes=..., stream_src=...): + """Import an image. Similar to the ``docker import`` command. + + If ``src`` is a string or unicode string, it will first be treated as a + path to a tarball on the local system. If there is an error reading + from that file, ``src`` will be treated as a URL instead to fetch the + image from. You can also pass an open file handle as ``src``, in which + case the data will be read from that file. + + If ``src`` is unset but ``image`` is set, the ``image`` parameter will + be taken as the name of an existing image to import from. + + Args: + src (str or file): Path to tarfile, URL, or file-like object + repository (str): The repository to create + tag (str): The tag to apply + image (str): Use another image like the ``FROM`` Dockerfile + parameter + """ + ... + def import_image_from_data(self, data, repository=..., tag=..., changes=...): + """Like :py:meth:`~docker.api.image.ImageApiMixin.import_image`, but + allows importing in-memory bytes data. + + Args: + data (bytes collection): Bytes collection containing valid tar data + repository (str): The repository to create + tag (str): The tag to apply + """ + ... + def import_image_from_file(self, filename, repository=..., tag=..., changes=...): + """Like :py:meth:`~docker.api.image.ImageApiMixin.import_image`, but only + supports importing from a tar file on disk. + + Args: + filename (str): Full path to a tar file. + repository (str): The repository to create + tag (str): The tag to apply + + Raises: + IOError: File does not exist. + """ + ... + def import_image_from_stream(self, stream, repository=..., tag=..., changes=...): ... + def import_image_from_url(self, url, repository=..., tag=..., changes=...): + """Like :py:meth:`~docker.api.image.ImageApiMixin.import_image`, but only + supports importing from a URL. + + Args: + url (str): A URL pointing to a tar file. + repository (str): The repository to create + tag (str): The tag to apply + """ + ... + def import_image_from_image(self, image, repository=..., tag=..., changes=...): + """Like :py:meth:`~docker.api.image.ImageApiMixin.import_image`, but only + supports importing from another image, like the ``FROM`` Dockerfile + parameter. + + Args: + image (str): Image name to import from + repository (str): The repository to create + tag (str): The tag to apply + """ + ... + @utils.check_resource("image") + def inspect_image(self, image): + """Get detailed information about an image. Similar to the ``docker + inspect`` command, but only for images. + + Args: + image (str): The image to inspect + + Returns: + (dict): Similar to the output of ``docker inspect``, but as a + single dict + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.30") + @utils.check_resource("image") + def inspect_distribution(self, image, auth_config=...): + """Get image digest and platform information by contacting the registry. + + Args: + image (str): The image name to inspect + auth_config (dict): Override the credentials that are found in the + config for this request. ``auth_config`` should contain the + ``username`` and ``password`` keys to be valid. + + Returns: + (dict): A dict containing distribution data + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def load_image(self, data, quiet=...): # -> None: + """Load an image that was previously saved using + :py:meth:`~docker.api.image.ImageApiMixin.get_image` (or ``docker + save``). Similar to ``docker load``. + + Args: + data (binary): Image data to be loaded. + quiet (boolean): Suppress progress details in response. + + Returns: + (generator): Progress output as JSON objects. Only available for + API version >= 1.23 + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.25") + def prune_images(self, filters=...): + """Delete unused images. + + Args: + filters (dict): Filters to process on the prune list. + Available filters: + - dangling (bool): When set to true (or 1), prune only + unused and untagged images. + + Returns: + (dict): A dict containing a list of deleted image IDs and + the amount of disk space reclaimed in bytes. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def pull(self, repository, tag=..., stream=..., auth_config=..., decode=..., platform=..., all_tags=...): + """Pulls an image. Similar to the ``docker pull`` command. + + Args: + repository (str): The repository to pull + tag (str): The tag to pull. If ``tag`` is ``None`` or empty, it + is set to ``latest``. + stream (bool): Stream the output as a generator. Make sure to + consume the generator, otherwise pull might get cancelled. + auth_config (dict): Override the credentials that are found in the + config for this request. ``auth_config`` should contain the + ``username`` and ``password`` keys to be valid. + decode (bool): Decode the JSON data from the server into dicts. + Only applies with ``stream=True`` + platform (str): Platform in the format ``os[/arch[/variant]]`` + all_tags (bool): Pull all image tags, the ``tag`` parameter is + ignored. + + Returns: + (generator or str): The output + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> resp = client.api.pull('busybox', stream=True, decode=True) + ... for line in resp: + ... print(json.dumps(line, indent=4)) + { + "status": "Pulling image (latest) from busybox", + "progressDetail": {}, + "id": "e72ac664f4f0" + } + { + "status": "Pulling image (latest) from busybox, endpoint: ...", + "progressDetail": {}, + "id": "e72ac664f4f0" + } + + """ + ... + def push(self, repository, tag=..., stream=..., auth_config=..., decode=...): + """Push an image or a repository to the registry. Similar to the ``docker + push`` command. + + Args: + repository (str): The repository to push to + tag (str): An optional tag to push + stream (bool): Stream the output as a blocking generator + auth_config (dict): Override the credentials that are found in the + config for this request. ``auth_config`` should contain the + ``username`` and ``password`` keys to be valid. + decode (bool): Decode the JSON data from the server into dicts. + Only applies with ``stream=True`` + + Returns: + (generator or str): The output from the server. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> resp = client.api.push( + ... 'yourname/app', + ... stream=True, + ... decode=True, + ... ) + ... for line in resp: + ... print(line) + {'status': 'Pushing repository yourname/app (1 tags)'} + {'status': 'Pushing','progressDetail': {}, 'id': '511136ea3c5a'} + {'status': 'Image already pushed, skipping', 'progressDetail':{}, + 'id': '511136ea3c5a'} + ... + + """ + ... + @utils.check_resource("image") + def remove_image(self, image, force=..., noprune=...): + """Remove an image. Similar to the ``docker rmi`` command. + + Args: + image (str): The image to remove + force (bool): Force removal of the image + noprune (bool): Do not delete untagged parents + """ + ... + def search(self, term, limit=...): + """Search for images on Docker Hub. Similar to the ``docker search`` + command. + + Args: + term (str): A term to search for. + limit (int): The maximum number of results to return. + + Returns: + (list of dicts): The response of the search. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("image") + def tag(self, image, repository, tag=..., force=...): + """Tag an image into a repository. Similar to the ``docker tag`` command. + + Args: + image (str): The image to tag + repository (str): The repository to set for the tag + tag (str): The tag name + force (bool): Force + + Returns: + (bool): ``True`` if successful + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> client.api.tag('ubuntu', 'localhost:5000/ubuntu', 'latest', + force=True) + """ + ... + +def is_file(src): ... diff --git a/typings/docker/api/network.pyi b/typings/docker/api/network.pyi new file mode 100644 index 00000000..e9788893 --- /dev/null +++ b/typings/docker/api/network.pyi @@ -0,0 +1,174 @@ +"""This type stub file was generated by pyright.""" + +from ..utils import check_resource +from ..utils import minimum_version + +class NetworkApiMixin: + def networks(self, names=..., ids=..., filters=...): + """List networks. Similar to the ``docker network ls`` command. + + Args: + names (:py:class:`list`): List of names to filter by + ids (:py:class:`list`): List of ids to filter by + filters (dict): Filters to be processed on the network list. + Available filters: + - ``driver=[]`` Matches a network's driver. + - ``label=[]``, ``label=[=]`` or a list of + such. + - ``type=["custom"|"builtin"]`` Filters networks by type. + + Returns: + (dict): List of network objects. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def create_network( + self, + name, + driver=..., + options=..., + ipam=..., + check_duplicate=..., + internal=..., + labels=..., + enable_ipv6=..., + attachable=..., + scope=..., + ingress=..., + ): + """Create a network. Similar to the ``docker network create``. + + Args: + name (str): Name of the network + driver (str): Name of the driver used to create the network + options (dict): Driver options as a key-value dictionary + ipam (IPAMConfig): Optional custom IP scheme for the network. + check_duplicate (bool): Request daemon to check for networks with + same name. Default: ``None``. + internal (bool): Restrict external access to the network. Default + ``False``. + labels (dict): Map of labels to set on the network. Default + ``None``. + enable_ipv6 (bool): Enable IPv6 on the network. Default ``False``. + attachable (bool): If enabled, and the network is in the global + scope, non-service containers on worker nodes will be able to + connect to the network. + scope (str): Specify the network's scope (``local``, ``global`` or + ``swarm``) + ingress (bool): If set, create an ingress network which provides + the routing-mesh in swarm mode. + + Returns: + (dict): The created network reference object + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + A network using the bridge driver: + + >>> client.api.create_network("network1", driver="bridge") + + You can also create more advanced networks with custom IPAM + configurations. For example, setting the subnet to + ``192.168.52.0/24`` and gateway address to ``192.168.52.254``. + + .. code-block:: python + + >>> ipam_pool = docker.types.IPAMPool( + subnet='192.168.52.0/24', + gateway='192.168.52.254' + ) + >>> ipam_config = docker.types.IPAMConfig( + pool_configs=[ipam_pool] + ) + >>> client.api.create_network("network1", driver="bridge", + ipam=ipam_config) + """ + ... + @minimum_version("1.25") + def prune_networks(self, filters=...): + """Delete unused networks. + + Args: + filters (dict): Filters to process on the prune list. + + Returns: + (dict): A dict containing a list of deleted network names and + the amount of disk space reclaimed in bytes. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @check_resource("net_id") + def remove_network(self, net_id): # -> None: + """Remove a network. Similar to the ``docker network rm`` command. + + Args: + net_id (str): The network's id + """ + ... + @check_resource("net_id") + def inspect_network(self, net_id, verbose=..., scope=...): + """Get detailed information about a network. + + Args: + net_id (str): ID of network + verbose (bool): Show the service details across the cluster in + swarm mode. + scope (str): Filter the network by scope (``swarm``, ``global`` + or ``local``). + """ + ... + @check_resource("container") + def connect_container_to_network( + self, + container, + net_id, + ipv4_address=..., + ipv6_address=..., + aliases=..., + links=..., + link_local_ips=..., + driver_opt=..., + mac_address=..., + ): # -> None: + """Connect a container to a network. + + Args: + container (str): container-id/name to be connected to the network + net_id (str): network id + aliases (:py:class:`list`): A list of aliases for this endpoint. + Names in that list can be used within the network to reach the + container. Defaults to ``None``. + links (:py:class:`list`): A list of links for this endpoint. + Containers declared in this list will be linked to this + container. Defaults to ``None``. + ipv4_address (str): The IP address of this container on the + network, using the IPv4 protocol. Defaults to ``None``. + ipv6_address (str): The IP address of this container on the + network, using the IPv6 protocol. Defaults to ``None``. + link_local_ips (:py:class:`list`): A list of link-local + (IPv4/IPv6) addresses. + mac_address (str): The MAC address of this container on the + network. Defaults to ``None``. + """ + ... + @check_resource("container") + def disconnect_container_from_network(self, container, net_id, force=...): # -> None: + """Disconnect a container from a network. + + Args: + container (str): container ID or name to be disconnected from the + network + net_id (str): network ID + force (bool): Force the container to disconnect from a network. + Default: ``False`` + """ + ... diff --git a/typings/docker/api/plugin.pyi b/typings/docker/api/plugin.pyi new file mode 100644 index 00000000..39b45de2 --- /dev/null +++ b/typings/docker/api/plugin.pyi @@ -0,0 +1,160 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class PluginApiMixin: + @utils.minimum_version("1.25") + @utils.check_resource("name") + def configure_plugin(self, name, options): # -> Literal[True]: + """Configure a plugin. + + Args: + name (string): The name of the plugin. The ``:latest`` tag is + optional, and is the default if omitted. + options (dict): A key-value mapping of options + + Returns: + ``True`` if successful + """ + ... + @utils.minimum_version("1.25") + def create_plugin(self, name, plugin_data_dir, gzip=...): # -> Literal[True]: + """Create a new plugin. + + Args: + name (string): The name of the plugin. The ``:latest`` tag is + optional, and is the default if omitted. + plugin_data_dir (string): Path to the plugin data directory. + Plugin data directory must contain the ``config.json`` + manifest file and the ``rootfs`` directory. + gzip (bool): Compress the context using gzip. Default: False + + Returns: + ``True`` if successful + """ + ... + @utils.minimum_version("1.25") + def disable_plugin(self, name, force=...): # -> Literal[True]: + """Disable an installed plugin. + + Args: + name (string): The name of the plugin. The ``:latest`` tag is + optional, and is the default if omitted. + force (bool): To enable the force query parameter. + + Returns: + ``True`` if successful + """ + ... + @utils.minimum_version("1.25") + def enable_plugin(self, name, timeout=...): # -> Literal[True]: + """Enable an installed plugin. + + Args: + name (string): The name of the plugin. The ``:latest`` tag is + optional, and is the default if omitted. + timeout (int): Operation timeout (in seconds). Default: 0 + + Returns: + ``True`` if successful + """ + ... + @utils.minimum_version("1.25") + def inspect_plugin(self, name): + """Retrieve plugin metadata. + + Args: + name (string): The name of the plugin. The ``:latest`` tag is + optional, and is the default if omitted. + + Returns: + A dict containing plugin info + """ + ... + @utils.minimum_version("1.25") + def pull_plugin(self, remote, privileges, name=...): + """Pull and install a plugin. After the plugin is installed, it can be + enabled using :py:meth:`~enable_plugin`. + + Args: + remote (string): Remote reference for the plugin to install. + The ``:latest`` tag is optional, and is the default if + omitted. + privileges (:py:class:`list`): A list of privileges the user + consents to grant to the plugin. Can be retrieved using + :py:meth:`~plugin_privileges`. + name (string): Local name for the pulled plugin. The + ``:latest`` tag is optional, and is the default if omitted. + + Returns: + An iterable object streaming the decoded API logs + """ + ... + @utils.minimum_version("1.25") + def plugins(self): + """Retrieve a list of installed plugins. + + Returns: + A list of dicts, one per plugin + """ + ... + @utils.minimum_version("1.25") + def plugin_privileges(self, name): + """Retrieve list of privileges to be granted to a plugin. + + Args: + name (string): Name of the remote plugin to examine. The + ``:latest`` tag is optional, and is the default if omitted. + + Returns: + A list of dictionaries representing the plugin's + permissions + + """ + ... + @utils.minimum_version("1.25") + @utils.check_resource("name") + def push_plugin(self, name): + """Push a plugin to the registry. + + Args: + name (string): Name of the plugin to upload. The ``:latest`` + tag is optional, and is the default if omitted. + + Returns: + ``True`` if successful + """ + ... + @utils.minimum_version("1.25") + @utils.check_resource("name") + def remove_plugin(self, name, force=...): # -> Literal[True]: + """Remove an installed plugin. + + Args: + name (string): Name of the plugin to remove. The ``:latest`` + tag is optional, and is the default if omitted. + force (bool): Disable the plugin before removing. This may + result in issues if the plugin is in use by a container. + + Returns: + ``True`` if successful + """ + ... + @utils.minimum_version("1.26") + @utils.check_resource("name") + def upgrade_plugin(self, name, remote, privileges): + """Upgrade an installed plugin. + + Args: + name (string): Name of the plugin to upgrade. The ``:latest`` + tag is optional and is the default if omitted. + remote (string): Remote reference to upgrade to. The + ``:latest`` tag is optional and is the default if omitted. + privileges (:py:class:`list`): A list of privileges the user + consents to grant to the plugin. Can be retrieved using + :py:meth:`~plugin_privileges`. + + Returns: + An iterable object streaming the decoded API logs + """ + ... diff --git a/typings/docker/api/secret.pyi b/typings/docker/api/secret.pyi new file mode 100644 index 00000000..23c878e4 --- /dev/null +++ b/typings/docker/api/secret.pyi @@ -0,0 +1,60 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class SecretApiMixin: + @utils.minimum_version("1.25") + def create_secret(self, name, data, labels=..., driver=...): + """Create a secret. + + Args: + name (string): Name of the secret + data (bytes): Secret data to be stored + labels (dict): A mapping of labels to assign to the secret + driver (DriverConfig): A custom driver configuration. If + unspecified, the default ``internal`` driver will be used + + Returns (dict): ID of the newly created secret + """ + ... + @utils.minimum_version("1.25") + @utils.check_resource("id") + def inspect_secret(self, id): + """Retrieve secret metadata. + + Args: + id (string): Full ID of the secret to inspect + + Returns (dict): A dictionary of metadata + + Raises: + :py:class:`docker.errors.NotFound` + if no secret with that ID exists + """ + ... + @utils.minimum_version("1.25") + @utils.check_resource("id") + def remove_secret(self, id): # -> Literal[True]: + """Remove a secret. + + Args: + id (string): Full ID of the secret to remove + + Returns (boolean): True if successful + + Raises: + :py:class:`docker.errors.NotFound` + if no secret with that ID exists + """ + ... + @utils.minimum_version("1.25") + def secrets(self, filters=...): + """List secrets. + + Args: + filters (dict): A map of filters to process on the secrets + list. Available filters: ``names`` + + Returns (list): A list of secrets + """ + ... diff --git a/typings/docker/api/service.pyi b/typings/docker/api/service.pyi new file mode 100644 index 00000000..389aecd4 --- /dev/null +++ b/typings/docker/api/service.pyi @@ -0,0 +1,217 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class ServiceApiMixin: + @utils.minimum_version("1.24") + def create_service( + self, + task_template, + name=..., + labels=..., + mode=..., + update_config=..., + networks=..., + endpoint_config=..., + endpoint_spec=..., + rollback_config=..., + ): + """Create a service. + + Args: + task_template (TaskTemplate): Specification of the task to start as + part of the new service. + name (string): User-defined name for the service. Optional. + labels (dict): A map of labels to associate with the service. + Optional. + mode (ServiceMode): Scheduling mode for the service (replicated + or global). Defaults to replicated. + update_config (UpdateConfig): Specification for the update strategy + of the service. Default: ``None`` + rollback_config (RollbackConfig): Specification for the rollback + strategy of the service. Default: ``None`` + networks (:py:class:`list`): List of network names or IDs or + :py:class:`~docker.types.NetworkAttachmentConfig` to attach the + service to. Default: ``None``. + endpoint_spec (EndpointSpec): Properties that can be configured to + access and load balance a service. Default: ``None``. + + Returns: + A dictionary containing an ``ID`` key for the newly created + service. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + @utils.check_resource("service") + def inspect_service(self, service, insert_defaults=...): + """Return information about a service. + + Args: + service (str): Service name or ID. + insert_defaults (boolean): If true, default values will be merged + into the service inspect output. + + Returns: + (dict): A dictionary of the server-side representation of the + service, including all relevant properties. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + @utils.check_resource("task") + def inspect_task(self, task): + """Retrieve information about a task. + + Args: + task (str): Task ID + + Returns: + (dict): Information about the task. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + @utils.check_resource("service") + def remove_service(self, service): # -> Literal[True]: + """Stop and remove a service. + + Args: + service (str): Service name or ID + + Returns: + ``True`` if successful. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + def services(self, filters=..., status=...): + """List services. + + Args: + filters (dict): Filters to process on the nodes list. Valid + filters: ``id``, ``name`` , ``label`` and ``mode``. + Default: ``None``. + status (bool): Include the service task count of running and + desired tasks. Default: ``None``. + + Returns: + A list of dictionaries containing data about each service. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.25") + @utils.check_resource("service") + def service_logs( + self, service, details=..., follow=..., stdout=..., stderr=..., since=..., timestamps=..., tail=..., is_tty=... + ): + """Get log stream for a service. + Note: This endpoint works only for services with the ``json-file`` + or ``journald`` logging drivers. + + Args: + service (str): ID or name of the service + details (bool): Show extra details provided to logs. + Default: ``False`` + follow (bool): Keep connection open to read logs as they are + sent by the Engine. Default: ``False`` + stdout (bool): Return logs from ``stdout``. Default: ``False`` + stderr (bool): Return logs from ``stderr``. Default: ``False`` + since (int): UNIX timestamp for the logs staring point. + Default: 0 + timestamps (bool): Add timestamps to every log line. + tail (string or int): Number of log lines to be returned, + counting from the current end of the logs. Specify an + integer or ``'all'`` to output all log lines. + Default: ``all`` + is_tty (bool): Whether the service's :py:class:`ContainerSpec` + enables the TTY option. If omitted, the method will query + the Engine for the information, causing an additional + roundtrip. + + Returns (generator): Logs for the service. + """ + ... + @utils.minimum_version("1.24") + def tasks(self, filters=...): + """Retrieve a list of tasks. + + Args: + filters (dict): A map of filters to process on the tasks list. + Valid filters: ``id``, ``name``, ``service``, ``node``, + ``label`` and ``desired-state``. + + Returns: + (:py:class:`list`): List of task dictionaries. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + @utils.check_resource("service") + def update_service( + self, + service, + version, + task_template=..., + name=..., + labels=..., + mode=..., + update_config=..., + networks=..., + endpoint_config=..., + endpoint_spec=..., + fetch_current_spec=..., + rollback_config=..., + ): + """Update a service. + + Args: + service (string): A service identifier (either its name or service + ID). + version (int): The version number of the service object being + updated. This is required to avoid conflicting writes. + task_template (TaskTemplate): Specification of the updated task to + start as part of the service. + name (string): New name for the service. Optional. + labels (dict): A map of labels to associate with the service. + Optional. + mode (ServiceMode): Scheduling mode for the service (replicated + or global). Defaults to replicated. + update_config (UpdateConfig): Specification for the update strategy + of the service. Default: ``None``. + rollback_config (RollbackConfig): Specification for the rollback + strategy of the service. Default: ``None`` + networks (:py:class:`list`): List of network names or IDs or + :py:class:`~docker.types.NetworkAttachmentConfig` to attach the + service to. Default: ``None``. + endpoint_spec (EndpointSpec): Properties that can be configured to + access and load balance a service. Default: ``None``. + fetch_current_spec (boolean): Use the undefined settings from the + current specification of the service. Default: ``False`` + + Returns: + A dictionary containing a ``Warnings`` key. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/api/swarm.pyi b/typings/docker/api/swarm.pyi new file mode 100644 index 00000000..3518339c --- /dev/null +++ b/typings/docker/api/swarm.pyi @@ -0,0 +1,318 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +log = ... + +class SwarmApiMixin: + def create_swarm_spec(self, *args, **kwargs): # -> SwarmSpec: + """Create a :py:class:`docker.types.SwarmSpec` instance that can be used + as the ``swarm_spec`` argument in + :py:meth:`~docker.api.swarm.SwarmApiMixin.init_swarm`. + + Args: + task_history_retention_limit (int): Maximum number of tasks + history stored. + snapshot_interval (int): Number of logs entries between snapshot. + keep_old_snapshots (int): Number of snapshots to keep beyond the + current snapshot. + log_entries_for_slow_followers (int): Number of log entries to + keep around to sync up slow followers after a snapshot is + created. + heartbeat_tick (int): Amount of ticks (in seconds) between each + heartbeat. + election_tick (int): Amount of ticks (in seconds) needed without a + leader to trigger a new election. + dispatcher_heartbeat_period (int): The delay for an agent to send + a heartbeat to the dispatcher. + node_cert_expiry (int): Automatic expiry for nodes certificates. + external_cas (:py:class:`list`): Configuration for forwarding + signing requests to an external certificate authority. Use + a list of :py:class:`docker.types.SwarmExternalCA`. + name (string): Swarm's name + labels (dict): User-defined key/value metadata. + signing_ca_cert (str): The desired signing CA certificate for all + swarm node TLS leaf certificates, in PEM format. + signing_ca_key (str): The desired signing CA key for all swarm + node TLS leaf certificates, in PEM format. + ca_force_rotate (int): An integer whose purpose is to force swarm + to generate a new signing CA certificate and key, if none have + been specified. + autolock_managers (boolean): If set, generate a key and use it to + lock data stored on the managers. + log_driver (DriverConfig): The default log driver to use for tasks + created in the orchestrator. + + Returns: + :py:class:`docker.types.SwarmSpec` + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> spec = client.api.create_swarm_spec( + snapshot_interval=5000, log_entries_for_slow_followers=1200 + ) + >>> client.api.init_swarm( + advertise_addr='eth0', listen_addr='0.0.0.0:5000', + force_new_cluster=False, swarm_spec=spec + ) + """ + ... + @utils.minimum_version("1.24") + def get_unlock_key(self): + """Get the unlock key for this Swarm manager. + + Returns: + A ``dict`` containing an ``UnlockKey`` member + """ + ... + @utils.minimum_version("1.24") + def init_swarm( + self, + advertise_addr=..., + listen_addr=..., + force_new_cluster=..., + swarm_spec=..., + default_addr_pool=..., + subnet_size=..., + data_path_addr=..., + data_path_port=..., + ): + """Initialize a new Swarm using the current connected engine as the first + node. + + Args: + advertise_addr (string): Externally reachable address advertised + to other nodes. This can either be an address/port combination + in the form ``192.168.1.1:4567``, or an interface followed by a + port number, like ``eth0:4567``. If the port number is omitted, + the port number from the listen address is used. If + ``advertise_addr`` is not specified, it will be automatically + detected when possible. Default: None + listen_addr (string): Listen address used for inter-manager + communication, as well as determining the networking interface + used for the VXLAN Tunnel Endpoint (VTEP). This can either be + an address/port combination in the form ``192.168.1.1:4567``, + or an interface followed by a port number, like ``eth0:4567``. + If the port number is omitted, the default swarm listening port + is used. Default: '0.0.0.0:2377' + force_new_cluster (bool): Force creating a new Swarm, even if + already part of one. Default: False + swarm_spec (dict): Configuration settings of the new Swarm. Use + ``APIClient.create_swarm_spec`` to generate a valid + configuration. Default: None + default_addr_pool (list of strings): Default Address Pool specifies + default subnet pools for global scope networks. Each pool + should be specified as a CIDR block, like '10.0.0.0/8'. + Default: None + subnet_size (int): SubnetSize specifies the subnet size of the + networks created from the default subnet pool. Default: None + data_path_addr (string): Address or interface to use for data path + traffic. For example, 192.168.1.1, or an interface, like eth0. + data_path_port (int): Port number to use for data path traffic. + Acceptable port range is 1024 to 49151. If set to ``None`` or + 0, the default port 4789 will be used. Default: None + + Returns: + (str): The ID of the created node. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + def inspect_swarm(self): + """Retrieve low-level information about the current swarm. + + Returns: + A dictionary containing data about the swarm. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("node_id") + @utils.minimum_version("1.24") + def inspect_node(self, node_id): + """Retrieve low-level information about a swarm node. + + Args: + node_id (string): ID of the node to be inspected. + + Returns: + A dictionary containing data about this node. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + def join_swarm( + self, remote_addrs, join_token, listen_addr=..., advertise_addr=..., data_path_addr=... + ): # -> Literal[True]: + """Make this Engine join a swarm that has already been created. + + Args: + remote_addrs (:py:class:`list`): Addresses of one or more manager + nodes already participating in the Swarm to join. + join_token (string): Secret token for joining this Swarm. + listen_addr (string): Listen address used for inter-manager + communication if the node gets promoted to manager, as well as + determining the networking interface used for the VXLAN Tunnel + Endpoint (VTEP). Default: ``'0.0.0.0:2377`` + advertise_addr (string): Externally reachable address advertised + to other nodes. This can either be an address/port combination + in the form ``192.168.1.1:4567``, or an interface followed by a + port number, like ``eth0:4567``. If the port number is omitted, + the port number from the listen address is used. If + AdvertiseAddr is not specified, it will be automatically + detected when possible. Default: ``None`` + data_path_addr (string): Address or interface to use for data path + traffic. For example, 192.168.1.1, or an interface, like eth0. + + Returns: + ``True`` if the request went through. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + def leave_swarm(self, force=...): # -> Literal[True]: + """Leave a swarm. + + Args: + force (bool): Leave the swarm even if this node is a manager. + Default: ``False`` + + Returns: + ``True`` if the request went through. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.minimum_version("1.24") + def nodes(self, filters=...): + """List swarm nodes. + + Args: + filters (dict): Filters to process on the nodes list. Valid + filters: ``id``, ``name``, ``membership`` and ``role``. + Default: ``None`` + + Returns: + A list of dictionaries containing data about each swarm node. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @utils.check_resource("node_id") + @utils.minimum_version("1.24") + def remove_node(self, node_id, force=...): # -> Literal[True]: + """Remove a node from the swarm. + + Args: + node_id (string): ID of the node to be removed. + force (bool): Force remove an active node. Default: `False` + + Raises: + :py:class:`docker.errors.NotFound` + If the node referenced doesn't exist in the swarm. + + :py:class:`docker.errors.APIError` + If the server returns an error. + + Returns: + `True` if the request was successful. + """ + ... + @utils.minimum_version("1.24") + def unlock_swarm(self, key): # -> Literal[True]: + """Unlock a locked swarm. + + Args: + key (string): The unlock key as provided by + :py:meth:`get_unlock_key` + + Raises: + :py:class:`docker.errors.InvalidArgument` + If the key argument is in an incompatible format + + :py:class:`docker.errors.APIError` + If the server returns an error. + + Returns: + `True` if the request was successful. + + Example: + >>> key = client.api.get_unlock_key() + >>> client.unlock_swarm(key) + + """ + ... + @utils.minimum_version("1.24") + def update_node(self, node_id, version, node_spec=...): # -> Literal[True]: + """Update the node's configuration. + + Args: + node_id (string): ID of the node to be updated. + version (int): The version number of the node object being + updated. This is required to avoid conflicting writes. + node_spec (dict): Configuration settings to update. Any values + not provided will be removed. Default: ``None`` + + Returns: + `True` if the request went through. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> node_spec = {'Availability': 'active', + 'Name': 'node-name', + 'Role': 'manager', + 'Labels': {'foo': 'bar'} + } + >>> client.api.update_node(node_id='24ifsmvkjbyhk', version=8, + node_spec=node_spec) + + """ + ... + @utils.minimum_version("1.24") + def update_swarm( + self, version, swarm_spec=..., rotate_worker_token=..., rotate_manager_token=..., rotate_manager_unlock_key=... + ): # -> Literal[True]: + """Update the Swarm's configuration. + + Args: + version (int): The version number of the swarm object being + updated. This is required to avoid conflicting writes. + swarm_spec (dict): Configuration settings to update. Use + :py:meth:`~docker.api.swarm.SwarmApiMixin.create_swarm_spec` to + generate a valid configuration. Default: ``None``. + rotate_worker_token (bool): Rotate the worker join token. Default: + ``False``. + rotate_manager_token (bool): Rotate the manager join token. + Default: ``False``. + rotate_manager_unlock_key (bool): Rotate the manager unlock key. + Default: ``False``. + + Returns: + ``True`` if the request went through. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/api/volume.pyi b/typings/docker/api/volume.pyi new file mode 100644 index 00000000..deb06617 --- /dev/null +++ b/typings/docker/api/volume.pyi @@ -0,0 +1,112 @@ +"""This type stub file was generated by pyright.""" + +from .. import utils + +class VolumeApiMixin: + def volumes(self, filters=...): + """List volumes currently registered by the docker daemon. Similar to the + ``docker volume ls`` command. + + Args: + filters (dict): Server-side list filtering options. + + Returns: + (dict): Dictionary with list of volume objects as value of the + ``Volumes`` key. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> client.api.volumes() + {u'Volumes': [{u'Driver': u'local', + u'Mountpoint': u'/var/lib/docker/volumes/foobar/_data', + u'Name': u'foobar'}, + {u'Driver': u'local', + u'Mountpoint': u'/var/lib/docker/volumes/baz/_data', + u'Name': u'baz'}]} + """ + ... + def create_volume(self, name=..., driver=..., driver_opts=..., labels=...): + """Create and register a named volume. + + Args: + name (str): Name of the volume + driver (str): Name of the driver used to create the volume + driver_opts (dict): Driver options as a key-value dictionary + labels (dict): Labels to set on the volume + + Returns: + (dict): The created volume reference object + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> volume = client.api.create_volume( + ... name='foobar', + ... driver='local', + ... driver_opts={'foo': 'bar', 'baz': 'false'}, + ... labels={"key": "value"}, + ... ) + ... print(volume) + {u'Driver': u'local', + u'Labels': {u'key': u'value'}, + u'Mountpoint': u'/var/lib/docker/volumes/foobar/_data', + u'Name': u'foobar', + u'Scope': u'local'} + + """ + ... + def inspect_volume(self, name): + """Retrieve volume info by name. + + Args: + name (str): volume name + + Returns: + (dict): Volume information dictionary + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> client.api.inspect_volume('foobar') + {u'Driver': u'local', + u'Mountpoint': u'/var/lib/docker/volumes/foobar/_data', + u'Name': u'foobar'} + + """ + ... + @utils.minimum_version("1.25") + def prune_volumes(self, filters=...): + """Delete unused volumes. + + Args: + filters (dict): Filters to process on the prune list. + + Returns: + (dict): A dict containing a list of deleted volume names and + the amount of disk space reclaimed in bytes. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def remove_volume(self, name, force=...): # -> None: + """Remove a volume. Similar to the ``docker volume rm`` command. + + Args: + name (str): The volume's name + force (bool): Force removal of volumes that were already removed + out of band by the volume driver plugin. + + Raises: + :py:class:`docker.errors.APIError` + If volume failed to remove. + """ + ... diff --git a/typings/docker/client.pyi b/typings/docker/client.pyi new file mode 100644 index 00000000..3f8df31e --- /dev/null +++ b/typings/docker/client.pyi @@ -0,0 +1,83 @@ +from typing import Any +from typing import Self + +from .models.containers import ContainerCollection + +class DockerClient: + """A client for communicating with a Docker server. + + Example: + >>> import docker + >>> client = docker.DockerClient(base_url='unix://var/run/docker.sock') + + Args: + base_url (str): URL to the Docker server. For example, + ``unix:///var/run/docker.sock`` or ``tcp://127.0.0.1:1234``. + version (str): The version of the API to use. Set to ``auto`` to + automatically detect the server's version. Default: ``1.35`` + timeout (int): Default timeout for API calls, in seconds. + tls (bool or :py:class:`~docker.tls.TLSConfig`): Enable TLS. Pass + ``True`` to enable it with default options, or pass a + :py:class:`~docker.tls.TLSConfig` object to use custom + configuration. + user_agent (str): Set a custom user agent for requests to the server. + credstore_env (dict): Override environment variables when calling the + credential store process. + use_ssh_client (bool): If set to `True`, an ssh connection is made + via shelling out to the ssh client. Ensure the ssh client is + installed and configured on the host. + max_pool_size (int): The maximum number of connections + to save in the pool. + """ + + @classmethod + def from_env(cls, **kwargs: Any) -> Self: + """Return a client configured from environment variables. + + The environment variables used are the same as those used by the + Docker command-line client. They are: + + .. envvar:: DOCKER_HOST + + The URL to the Docker host. + + .. envvar:: DOCKER_TLS_VERIFY + + Verify the host against a CA certificate. + + .. envvar:: DOCKER_CERT_PATH + + A path to a directory containing TLS certificates to use when + connecting to the Docker host. + + Args: + version (str): The version of the API to use. Set to ``auto`` to + automatically detect the server's version. Default: ``auto`` + timeout (int): Default timeout for API calls, in seconds. + max_pool_size (int): The maximum number of connections + to save in the pool. + ssl_version (int): A valid `SSL version`_. + assert_hostname (bool): Verify the hostname of the server. + environment (dict): The environment to read environment variables + from. Default: the value of ``os.environ`` + credstore_env (dict): Override environment variables when calling + the credential store process. + use_ssh_client (bool): If set to `True`, an ssh connection is + made via shelling out to the ssh client. Ensure the ssh + client is installed and configured on the host. + + Example: + >>> import docker + >>> client = docker.from_env() + + .. _`SSL version`: + https://docs.python.org/3.5/library/ssl.html#ssl.PROTOCOL_TLSv1 + """ + @property + def containers(self) -> ContainerCollection: + """An object for managing containers on the server. See the + :doc:`containers documentation ` for full details. + """ + ... + +def from_env(**attrs: Any) -> DockerClient: ... diff --git a/typings/docker/constants.pyi b/typings/docker/constants.pyi new file mode 100644 index 00000000..23f2ab92 --- /dev/null +++ b/typings/docker/constants.pyi @@ -0,0 +1,19 @@ +DEFAULT_DOCKER_API_VERSION = ... +MINIMUM_DOCKER_API_VERSION = ... +DEFAULT_TIMEOUT_SECONDS = ... +STREAM_HEADER_SIZE_BYTES = ... +CONTAINER_LIMITS_KEYS = ... +DEFAULT_HTTP_HOST = ... +DEFAULT_UNIX_SOCKET = ... +DEFAULT_NPIPE = ... +BYTE_UNITS = ... +INSECURE_REGISTRY_DEPRECATION_WARNING = ... +IS_WINDOWS_PLATFORM = ... +WINDOWS_LONGPATH_PREFIX = ... +DEFAULT_USER_AGENT = ... +DEFAULT_NUM_POOLS = ... +DEFAULT_NUM_POOLS_SSH = ... +DEFAULT_MAX_POOL_SIZE = ... +DEFAULT_DATA_CHUNK_SIZE = ... +DEFAULT_SWARM_ADDR_POOL = ... +DEFAULT_SWARM_SUBNET_SIZE = ... diff --git a/typings/docker/context/__init__.pyi b/typings/docker/context/__init__.pyi new file mode 100644 index 00000000..c9f93959 --- /dev/null +++ b/typings/docker/context/__init__.pyi @@ -0,0 +1,2 @@ +from .api import ContextAPI as ContextAPI +from .context import Context as Context diff --git a/typings/docker/context/api.pyi b/typings/docker/context/api.pyi new file mode 100644 index 00000000..104555ca --- /dev/null +++ b/typings/docker/context/api.pyi @@ -0,0 +1,129 @@ +"""This type stub file was generated by pyright.""" + +class ContextAPI: + """Context API. + Contains methods for context management: + create, list, remove, get, inspect. + """ + + DEFAULT_CONTEXT = ... + @classmethod + def create_context( + cls, name, orchestrator=..., host=..., tls_cfg=..., default_namespace=..., skip_tls_verify=... + ): # -> Context: + """Creates a new context. + + Returns: + (Context): a Context object. + + Raises: + :py:class:`docker.errors.MissingContextParameter` + If a context name is not provided. + :py:class:`docker.errors.ContextAlreadyExists` + If a context with the name already exists. + :py:class:`docker.errors.ContextException` + If name is default. + + Example: + >>> from docker.context import ContextAPI + >>> ctx = ContextAPI.create_context(name='test') + >>> print(ctx.Metadata) + { + "Name": "test", + "Metadata": {}, + "Endpoints": { + "docker": { + "Host": "unix:///var/run/docker.sock", + "SkipTLSVerify": false + } + } + } + """ + ... + @classmethod + def get_context(cls, name=...): # -> Context | None: + """Retrieves a context object. + + Args: + name (str): The name of the context. + + Example: + >>> from docker.context import ContextAPI + >>> ctx = ContextAPI.get_context(name='test') + >>> print(ctx.Metadata) + { + "Name": "test", + "Metadata": {}, + "Endpoints": { + "docker": { + "Host": "unix:///var/run/docker.sock", + "SkipTLSVerify": false + } + } + } + """ + ... + @classmethod + def contexts(cls): # -> list[Context]: + """Context list. + + Returns: + (Context): List of context objects. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @classmethod + def get_current_context(cls): # -> Context | None: + """Get current context. + + Returns: + (Context): current context object. + """ + ... + @classmethod + def set_current_context(cls, name=...): ... + @classmethod + def remove_context(cls, name): # -> None: + """Remove a context. Similar to the ``docker context rm`` command. + + Args: + name (str): The name of the context + + Raises: + :py:class:`docker.errors.MissingContextParameter` + If a context name is not provided. + :py:class:`docker.errors.ContextNotFound` + If a context with the name does not exist. + :py:class:`docker.errors.ContextException` + If name is default. + + Example: + >>> from docker.context import ContextAPI + >>> ContextAPI.remove_context(name='test') + >>> + """ + ... + @classmethod + def inspect_context( + cls, name=... + ): # -> dict[str, str | dict[str, Unknown] | dict[Unknown | str, dict[str, bytes | Unknown | str | bool]]] | dict[str, Unknown | dict[str, Unknown] | dict[Unknown | str, dict[str, bytes | Unknown | str | bool]]]: + """Remove a context. Similar to the ``docker context inspect`` command. + + Args: + name (str): The name of the context + + Raises: + :py:class:`docker.errors.MissingContextParameter` + If a context name is not provided. + :py:class:`docker.errors.ContextNotFound` + If a context with the name does not exist. + + Example: + >>> from docker.context import ContextAPI + >>> ContextAPI.remove_context(name='test') + >>> + """ + ... diff --git a/typings/docker/context/config.pyi b/typings/docker/context/config.pyi new file mode 100644 index 00000000..f45def17 --- /dev/null +++ b/typings/docker/context/config.pyi @@ -0,0 +1,12 @@ +"""This type stub file was generated by pyright.""" + +METAFILE = ... + +def get_current_context_name(): ... +def write_context_name_to_docker_config(name=...): ... +def get_context_id(name): ... +def get_context_dir(): ... +def get_meta_dir(name=...): ... +def get_meta_file(name): ... +def get_tls_dir(name=..., endpoint=...): ... +def get_context_host(path=..., tls=...): ... diff --git a/typings/docker/context/context.pyi b/typings/docker/context/context.pyi new file mode 100644 index 00000000..dc57fc7d --- /dev/null +++ b/typings/docker/context/context.pyi @@ -0,0 +1,29 @@ +"""This type stub file was generated by pyright.""" + +class Context: + """A context.""" + + def __init__(self, name, orchestrator=..., host=..., endpoints=..., tls=...) -> None: ... + def set_endpoint(self, name=..., host=..., tls_cfg=..., skip_tls_verify=..., def_namespace=...): ... + def inspect(self): ... + @classmethod + def load_context(cls, name): ... + def save(self): ... + def remove(self): ... + def __repr__(self): ... + def __call__(self): ... + def is_docker_host(self): ... + @property + def Name(self): ... + @property + def Host(self): ... + @property + def Orchestrator(self): ... + @property + def Metadata(self): ... + @property + def TLSConfig(self): ... + @property + def TLSMaterial(self): ... + @property + def Storage(self): ... diff --git a/typings/docker/credentials/__init__.pyi b/typings/docker/credentials/__init__.pyi new file mode 100644 index 00000000..dd061032 --- /dev/null +++ b/typings/docker/credentials/__init__.pyi @@ -0,0 +1,7 @@ +from .constants import DEFAULT_LINUX_STORE as DEFAULT_LINUX_STORE +from .constants import DEFAULT_OSX_STORE as DEFAULT_OSX_STORE +from .constants import DEFAULT_WIN32_STORE as DEFAULT_WIN32_STORE +from .constants import PROGRAM_PREFIX as PROGRAM_PREFIX +from .errors import CredentialsNotFound as CredentialsNotFound +from .errors import StoreError as StoreError +from .store import Store as Store diff --git a/typings/docker/credentials/constants.pyi b/typings/docker/credentials/constants.pyi new file mode 100644 index 00000000..19868aad --- /dev/null +++ b/typings/docker/credentials/constants.pyi @@ -0,0 +1,4 @@ +PROGRAM_PREFIX = ... +DEFAULT_LINUX_STORE = ... +DEFAULT_OSX_STORE = ... +DEFAULT_WIN32_STORE = ... diff --git a/typings/docker/credentials/errors.pyi b/typings/docker/credentials/errors.pyi new file mode 100644 index 00000000..0f27a41c --- /dev/null +++ b/typings/docker/credentials/errors.pyi @@ -0,0 +1,7 @@ +"""This type stub file was generated by pyright.""" + +class StoreError(RuntimeError): ... +class CredentialsNotFound(StoreError): ... +class InitializationError(StoreError): ... + +def process_store_error(cpe, program): ... diff --git a/typings/docker/credentials/store.pyi b/typings/docker/credentials/store.pyi new file mode 100644 index 00000000..694642ad --- /dev/null +++ b/typings/docker/credentials/store.pyi @@ -0,0 +1,27 @@ +"""This type stub file was generated by pyright.""" + +class Store: + def __init__(self, program, environment=...) -> None: + """Create a store object that acts as an interface to + perform the basic operations for storing, retrieving + and erasing credentials using `program`. + """ + ... + def get(self, server): # -> Any: + """Retrieve credentials for `server`. If no credentials are found, + a `StoreError` will be raised. + """ + ... + def store(self, server, username, secret): # -> bytes: + """Store credentials for `server`. Raises a `StoreError` if an error + occurs. + """ + ... + def erase(self, server): # -> None: + """Erase credentials for `server`. Raises a `StoreError` if an error + occurs. + """ + ... + def list(self): # -> Any: + """List stored credentials. Requires v0.4.0+ of the helper.""" + ... diff --git a/typings/docker/credentials/utils.pyi b/typings/docker/credentials/utils.pyi new file mode 100644 index 00000000..6c2f546b --- /dev/null +++ b/typings/docker/credentials/utils.pyi @@ -0,0 +1,5 @@ +"""This type stub file was generated by pyright.""" + +def create_environment_dict(overrides): # -> dict[str, str]: + """Create and return a copy of os.environ with the specified overrides.""" + ... diff --git a/typings/docker/errors.pyi b/typings/docker/errors.pyi new file mode 100644 index 00000000..e2fa5d7c --- /dev/null +++ b/typings/docker/errors.pyi @@ -0,0 +1,62 @@ +import requests + +class DockerException(Exception): + """A base class from which all other exceptions inherit. + + If you want to catch all errors that the Docker SDK might raise, + catch this base exception. + """ + +def create_api_error_from_http_exception(e): + """Create a suitable APIError from requests.exceptions.HTTPError.""" + ... + +class APIError(requests.exceptions.HTTPError, DockerException): + """An HTTP error from the API.""" + + def __init__(self, message, response=..., explanation=...) -> None: ... + @property + def status_code(self): ... + def is_error(self): ... + def is_client_error(self): ... + def is_server_error(self): ... + +class NotFound(APIError): ... +class ImageNotFound(NotFound): ... +class InvalidVersion(DockerException): ... +class InvalidRepository(DockerException): ... +class InvalidConfigFile(DockerException): ... +class InvalidArgument(DockerException): ... +class DeprecatedMethod(DockerException): ... + +class TLSParameterError(DockerException): + def __init__(self, msg) -> None: ... + +class NullResource(DockerException, ValueError): ... + +class ContainerError(DockerException): + """Represents a container that has exited with a non-zero exit code.""" + + def __init__(self, container, exit_status, command, image, stderr) -> None: ... + +class StreamParseError(RuntimeError): + def __init__(self, reason) -> None: ... + +class BuildError(DockerException): + def __init__(self, reason, build_log) -> None: ... + +class ImageLoadError(DockerException): ... + +def create_unexpected_kwargs_error(name, kwargs): ... + +class MissingContextParameter(DockerException): + def __init__(self, param) -> None: ... + +class ContextAlreadyExists(DockerException): + def __init__(self, name) -> None: ... + +class ContextException(DockerException): + def __init__(self, msg) -> None: ... + +class ContextNotFound(DockerException): + def __init__(self, name) -> None: ... diff --git a/typings/docker/models/__init__.pyi b/typings/docker/models/__init__.pyi new file mode 100644 index 00000000..4413061e --- /dev/null +++ b/typings/docker/models/__init__.pyi @@ -0,0 +1 @@ +"""This type stub file was generated by pyright.""" diff --git a/typings/docker/models/configs.pyi b/typings/docker/models/configs.pyi new file mode 100644 index 00000000..d894ddfa --- /dev/null +++ b/typings/docker/models/configs.pyi @@ -0,0 +1,56 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Config(Model): + """A config.""" + + id_attribute = ... + def __repr__(self): ... + @property + def name(self): ... + def remove(self): + """Remove this config. + + Raises: + :py:class:`docker.errors.APIError` + If config failed to remove. + """ + ... + +class ConfigCollection(Collection): + """Configs on the Docker server.""" + + model = Config + def create(self, **kwargs): ... + def get(self, config_id): # -> Model: + """Get a config. + + Args: + config_id (str): Config ID. + + Returns: + (:py:class:`Config`): The config. + + Raises: + :py:class:`docker.errors.NotFound` + If the config does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def list(self, **kwargs): # -> list[Model]: + """List configs. Similar to the ``docker config ls`` command. + + Args: + filters (dict): Server-side list filtering options. + + Returns: + (list of :py:class:`Config`): The configs. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/models/containers.pyi b/typings/docker/models/containers.pyi new file mode 100644 index 00000000..dcea6996 --- /dev/null +++ b/typings/docker/models/containers.pyi @@ -0,0 +1,832 @@ +from typing import Any +from typing import Iterator +from typing import Literal +from typing import overload + +from .resource import Collection +from .resource import Model + +class Container(Model): + """Local representation of a container object. Detailed configuration may + be accessed through the :py:attr:`attrs` attribute. Note that local + attributes are cached; users may call :py:meth:`reload` to + query the Docker daemon for the current properties, causing + :py:attr:`attrs` to be refreshed. + """ + + @property + def name(self) -> str: + """The name of the container.""" + @property + def image(self): # -> None: + """The image of the container.""" + @property + def labels(self) -> dict[Any, Any]: + """The labels of a container as dictionary.""" + @property + def status(self) -> Literal["created", "restarting", "running", "removing", "paused", "exited"]: + """The status of the container. For example, ``running``, or ``exited``.""" + @property + def ports(self) -> dict[Any, Any]: + """The ports that the container exposes as a dictionary.""" + def attach(self, **kwargs): + """Attach to this container. + + :py:meth:`logs` is a wrapper around this method, which you can + use instead if you want to fetch/stream container output without first + retrieving the entire backlog. + + Args: + stdout (bool): Include stdout. + stderr (bool): Include stderr. + stream (bool): Return container output progressively as an iterator + of strings, rather than a single string. + logs (bool): Include the container's previous output. + + Returns: + By default, the container's output as a single string. + + If ``stream=True``, an iterator of output strings. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + def attach_socket(self, **kwargs): + """Like :py:meth:`attach`, but returns the underlying socket-like object + for the HTTP request. + + Args: + params (dict): Dictionary of request parameters (e.g. ``stdout``, + ``stderr``, ``stream``). + ws (bool): Use websockets instead of raw HTTP. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def commit(self, repository=..., tag=..., **kwargs): + """Commit a container to an image. Similar to the ``docker commit`` + command. + + Args: + repository (str): The repository to push the image to + tag (str): The tag to push + message (str): A commit message + author (str): The name of the author + changes (str): Dockerfile instructions to apply while committing + conf (dict): The configuration for the container. See the + `Engine API documentation + `_ + for full details. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def diff(self): + """Inspect changes on a container's filesystem. + + Returns: + (list) A list of dictionaries containing the attributes `Path` + and `Kind`. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def exec_run( + self, + cmd, + stdout=..., + stderr=..., + stdin=..., + tty=..., + privileged=..., + user=..., + detach=..., + stream=..., + socket=..., + environment=..., + workdir=..., + demux=..., + ): # -> ExecResult: + """Run a command inside this container. Similar to + ``docker exec``. + + Args: + cmd (str or list): Command to be executed + stdout (bool): Attach to stdout. Default: ``True`` + stderr (bool): Attach to stderr. Default: ``True`` + stdin (bool): Attach to stdin. Default: ``False`` + tty (bool): Allocate a pseudo-TTY. Default: False + privileged (bool): Run as privileged. + user (str): User to execute command as. Default: root + detach (bool): If true, detach from the exec command. + Default: False + stream (bool): Stream response data. Default: False + socket (bool): Return the connection socket to allow custom + read/write operations. Default: False + environment (dict or list): A dictionary or a list of strings in + the following format ``["PASSWORD=xxx"]`` or + ``{"PASSWORD": "xxx"}``. + workdir (str): Path to working directory for this exec session + demux (bool): Return stdout and stderr separately + + Returns: + (ExecResult): A tuple of (exit_code, output) + exit_code: (int): + Exit code for the executed command or ``None`` if + either ``stream`` or ``socket`` is ``True``. + output: (generator, bytes, or tuple): + If ``stream=True``, a generator yielding response chunks. + If ``socket=True``, a socket object for the connection. + If ``demux=True``, a tuple of two bytes: stdout and stderr. + A bytestring containing response data otherwise. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def export(self, chunk_size=...): + """Export the contents of the container's filesystem as a tar archive. + + Args: + chunk_size (int): The number of bytes returned by each iteration + of the generator. If ``None``, data will be streamed as it is + received. Default: 2 MB + + Returns: + (str): The filesystem tar archive + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def get_archive(self, path, chunk_size=..., encode_stream=...): + """Retrieve a file or folder from the container in the form of a tar + archive. + + Args: + path (str): Path to the file or folder to retrieve + chunk_size (int): The number of bytes returned by each iteration + of the generator. If ``None``, data will be streamed as it is + received. Default: 2 MB + encode_stream (bool): Determines if data should be encoded + (gzip-compressed) during transmission. Default: False + + Returns: + (tuple): First element is a raw tar data stream. Second element is + a dict containing ``stat`` information on the specified ``path``. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> f = open('./sh_bin.tar', 'wb') + >>> bits, stat = container.get_archive('/bin/sh') + >>> print(stat) + {'name': 'sh', 'size': 1075464, 'mode': 493, + 'mtime': '2018-10-01T15:37:48-07:00', 'linkTarget': ''} + >>> for chunk in bits: + ... f.write(chunk) + >>> f.close() + """ + ... + def kill(self, signal: str | int = ...) -> None: + """Kill or send a signal to the container. + + Args: + signal (str or int): The signal to send. Defaults to ``SIGKILL`` + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + @overload + def logs(self, stream: Literal[False] = False, **kwargs: Any) -> bytes: ... + @overload + def logs(self, stream: Literal[True] = ..., **kwargs: Any) -> Iterator[bytes]: + """Get logs from this container. Similar to the ``docker logs`` command. + + The ``stream`` parameter makes the ``logs`` function return a blocking + generator you can iterate over to retrieve log output as it happens. + + Args: + stdout (bool): Get ``STDOUT``. Default ``True`` + stderr (bool): Get ``STDERR``. Default ``True`` + stream (bool): Stream the response. Default ``False`` + timestamps (bool): Show timestamps. Default ``False`` + tail (str or int): Output specified number of lines at the end of + logs. Either an integer of number of lines or the string + ``all``. Default ``all`` + since (datetime, int, or float): Show logs since a given datetime, + integer epoch (in seconds) or float (in nanoseconds) + follow (bool): Follow log output. Default ``False`` + until (datetime, int, or float): Show logs that occurred before + the given datetime, integer epoch (in seconds), or + float (in nanoseconds) + + Returns: + (generator or str): Logs from the container. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def pause(self): + """Pauses all processes within this container. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def put_archive(self, path, data): + """Insert a file or folder in this container using a tar archive as + source. + + Args: + path (str): Path inside the container where the file(s) will be + extracted. Must exist. + data (bytes or stream): tar data to be extracted + + Returns: + (bool): True if the call succeeds. + + Raises: + :py:class:`~docker.errors.APIError` If an error occurs. + """ + ... + def remove(self, v: bool | None = ..., link: bool | None = ..., force: bool | None = ...) -> None: + """Remove this container. Similar to the ``docker rm`` command. + + Args: + v (bool): Remove the volumes associated with the container + link (bool): Remove the specified link and not the underlying + container + force (bool): Force the removal of a running container (uses + ``SIGKILL``) + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def rename(self, name): + """Rename this container. Similar to the ``docker rename`` command. + + Args: + name (str): New name for the container + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def resize(self, height, width): + """Resize the tty session. + + Args: + height (int): Height of tty session + width (int): Width of tty session + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def restart(self, **kwargs): + """Restart this container. Similar to the ``docker restart`` command. + + Args: + timeout (int): Number of seconds to try to stop for before killing + the container. Once killed it will then be restarted. Default + is 10 seconds. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def start(self, **kwargs): + """Start this container. Similar to the ``docker start`` command, but + doesn't support attach options. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def stats(self, **kwargs): + """Stream statistics for this container. Similar to the + ``docker stats`` command. + + Args: + decode (bool): If set to true, stream will be decoded into dicts + on the fly. Only applicable if ``stream`` is True. + False by default. + stream (bool): If set to false, only the current stats will be + returned instead of a stream. True by default. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def stop(self, **kwargs: Any) -> None: + """Stops a container. Similar to the ``docker stop`` command. + + Args: + timeout (int): Timeout in seconds to wait for the container to + stop before sending a ``SIGKILL``. Default: 10 + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def top(self, **kwargs: Any) -> str: + """Display the running processes of the container. + + Args: + ps_args (str): An optional arguments passed to ps (e.g. ``aux``) + + Returns: + (str): The output of the top + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def unpause(self): + """Unpause all processes within the container. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def update(self, **kwargs): + """Update resource configuration of the containers. + + Args: + blkio_weight (int): Block IO (relative weight), between 10 and 1000 + cpu_period (int): Limit CPU CFS (Completely Fair Scheduler) period + cpu_quota (int): Limit CPU CFS (Completely Fair Scheduler) quota + cpu_shares (int): CPU shares (relative weight) + cpuset_cpus (str): CPUs in which to allow execution + cpuset_mems (str): MEMs in which to allow execution + mem_limit (int or str): Memory limit + mem_reservation (int or str): Memory soft limit + memswap_limit (int or str): Total memory (memory + swap), -1 to + disable swap + kernel_memory (int or str): Kernel memory limit + restart_policy (dict): Restart policy dictionary + + Returns: + (dict): Dictionary containing a ``Warnings`` key. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def wait(self, **kwargs: Any) -> dict[Any, Any]: + """Block until the container stops, then return its exit code. Similar to + the ``docker wait`` command. + + Args: + timeout (int): Request timeout + condition (str): Wait until a container state reaches the given + condition, either ``not-running`` (default), ``next-exit``, + or ``removed`` + + Returns: + (dict): The API's response as a Python dictionary, including + the container's exit code under the ``StatusCode`` attribute. + + Raises: + :py:class:`requests.exceptions.ReadTimeout` + If the timeout is exceeded. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + +class ContainerCollection(Collection): + model = Container + @overload + def run( + self, image: str, command: list[str] | str = ..., detach: Literal[True] = ..., **kwargs: Any + ) -> Container: ... + @overload + def run(self, image: str, command: list[str] | str = ..., detach: Literal[False] = ..., **kwargs: Any) -> bytes: + """Run a container. By default, it will wait for the container to finish + and return its logs, similar to ``docker run``. + + If the ``detach`` argument is ``True``, it will start the container + and immediately return a :py:class:`Container` object, similar to + ``docker run -d``. + + Example: + Run a container and get its output: + + >>> import docker + >>> client = docker.from_env() + >>> client.containers.run('alpine', 'echo hello world') + b'hello world\\n' + + Run a container and detach: + + >>> container = client.containers.run('bfirsh/reticulate-splines', + detach=True) + >>> container.logs() + 'Reticulating spline 1...\\nReticulating spline 2...\\n' + + Args: + image (str): The image to run. + command (str or list): The command to run in the container. + auto_remove (bool): enable auto-removal of the container on daemon + side when the container's process exits. + blkio_weight_device: Block IO weight (relative device weight) in + the form of: ``[{"Path": "device_path", "Weight": weight}]``. + blkio_weight: Block IO weight (relative weight), accepts a weight + value between 10 and 1000. + cap_add (list of str): Add kernel capabilities. For example, + ``["SYS_ADMIN", "MKNOD"]``. + cap_drop (list of str): Drop kernel capabilities. + cgroup_parent (str): Override the default parent cgroup. + cgroupns (str): Override the default cgroup namespace mode for the + container. One of: + - ``private`` the container runs in its own private cgroup + namespace. + - ``host`` use the host system's cgroup namespace. + cpu_count (int): Number of usable CPUs (Windows only). + cpu_percent (int): Usable percentage of the available CPUs + (Windows only). + cpu_period (int): The length of a CPU period in microseconds. + cpu_quota (int): Microseconds of CPU time that the container can + get in a CPU period. + cpu_rt_period (int): Limit CPU real-time period in microseconds. + cpu_rt_runtime (int): Limit CPU real-time runtime in microseconds. + cpu_shares (int): CPU shares (relative weight). + cpuset_cpus (str): CPUs in which to allow execution (``0-3``, + ``0,1``). + cpuset_mems (str): Memory nodes (MEMs) in which to allow execution + (``0-3``, ``0,1``). Only effective on NUMA systems. + detach (bool): Run container in the background and return a + :py:class:`Container` object. + device_cgroup_rules (:py:class:`list`): A list of cgroup rules to + apply to the container. + device_read_bps: Limit read rate (bytes per second) from a device + in the form of: `[{"Path": "device_path", "Rate": rate}]` + device_read_iops: Limit read rate (IO per second) from a device. + device_write_bps: Limit write rate (bytes per second) from a + device. + device_write_iops: Limit write rate (IO per second) from a device. + devices (:py:class:`list`): Expose host devices to the container, + as a list of strings in the form + ``::``. + + For example, ``/dev/sda:/dev/xvda:rwm`` allows the container + to have read-write access to the host's ``/dev/sda`` via a + node named ``/dev/xvda`` inside the container. + device_requests (:py:class:`list`): Expose host resources such as + GPUs to the container, as a list of + :py:class:`docker.types.DeviceRequest` instances. + dns (:py:class:`list`): Set custom DNS servers. + dns_opt (:py:class:`list`): Additional options to be added to the + container's ``resolv.conf`` file. + dns_search (:py:class:`list`): DNS search domains. + domainname (str or list): Set custom DNS search domains. + entrypoint (str or list): The entrypoint for the container. + environment (dict or list): Environment variables to set inside + the container, as a dictionary or a list of strings in the + format ``["SOMEVARIABLE=xxx"]``. + extra_hosts (dict): Additional hostnames to resolve inside the + container, as a mapping of hostname to IP address. + group_add (:py:class:`list`): List of additional group names and/or + IDs that the container process will run as. + healthcheck (dict): Specify a test to perform to check that the + container is healthy. The dict takes the following keys: + + - test (:py:class:`list` or str): Test to perform to determine + container health. Possible values: + + - Empty list: Inherit healthcheck from parent image + - ``["NONE"]``: Disable healthcheck + - ``["CMD", args...]``: exec arguments directly. + - ``["CMD-SHELL", command]``: Run command in the system's + default shell. + + If a string is provided, it will be used as a ``CMD-SHELL`` + command. + - interval (int): The time to wait between checks in + nanoseconds. It should be 0 or at least 1000000 (1 ms). + - timeout (int): The time to wait before considering the check + to have hung. It should be 0 or at least 1000000 (1 ms). + - retries (int): The number of consecutive failures needed to + consider a container as unhealthy. + - start_period (int): Start period for the container to + initialize before starting health-retries countdown in + nanoseconds. It should be 0 or at least 1000000 (1 ms). + hostname (str): Optional hostname for the container. + init (bool): Run an init inside the container that forwards + signals and reaps processes + init_path (str): Path to the docker-init binary + ipc_mode (str): Set the IPC mode for the container. + isolation (str): Isolation technology to use. Default: `None`. + kernel_memory (int or str): Kernel memory limit + labels (dict or list): A dictionary of name-value labels (e.g. + ``{"label1": "value1", "label2": "value2"}``) or a list of + names of labels to set with empty values (e.g. + ``["label1", "label2"]``) + links (dict): Mapping of links using the + ``{'container': 'alias'}`` format. The alias is optional. + Containers declared in this dict will be linked to the new + container using the provided alias. Default: ``None``. + log_config (LogConfig): Logging configuration. + lxc_conf (dict): LXC config. + mac_address (str): MAC address to assign to the container. + mem_limit (int or str): Memory limit. Accepts float values + (which represent the memory limit of the created container in + bytes) or a string with a units identification char + (``100000b``, ``1000k``, ``128m``, ``1g``). If a string is + specified without a units character, bytes are assumed as an + intended unit. + mem_reservation (int or str): Memory soft limit. + mem_swappiness (int): Tune a container's memory swappiness + behavior. Accepts number between 0 and 100. + memswap_limit (str or int): Maximum amount of memory + swap a + container is allowed to consume. + mounts (:py:class:`list`): Specification for mounts to be added to + the container. More powerful alternative to ``volumes``. Each + item in the list is expected to be a + :py:class:`docker.types.Mount` object. + name (str): The name for this container. + nano_cpus (int): CPU quota in units of 1e-9 CPUs. + network (str): Name of the network this container will be connected + to at creation time. You can connect to additional networks + using :py:meth:`Network.connect`. Incompatible with + ``network_mode``. + network_disabled (bool): Disable networking. + network_mode (str): One of: + + - ``bridge`` Create a new network stack for the container on + the bridge network. + - ``none`` No networking for this container. + - ``container:`` Reuse another container's network + stack. + - ``host`` Use the host network stack. + This mode is incompatible with ``ports``. + + Incompatible with ``network``. + network_driver_opt (dict): A dictionary of options to provide + to the network driver. Defaults to ``None``. Used in + conjuction with ``network``. Incompatible + with ``network_mode``. + oom_kill_disable (bool): Whether to disable OOM killer. + oom_score_adj (int): An integer value containing the score given + to the container in order to tune OOM killer preferences. + pid_mode (str): If set to ``host``, use the host PID namespace + inside the container. + pids_limit (int): Tune a container's pids limit. Set ``-1`` for + unlimited. + platform (str): Platform in the format ``os[/arch[/variant]]``. + Only used if the method needs to pull the requested image. + ports (dict): Ports to bind inside the container. + + The keys of the dictionary are the ports to bind inside the + container, either as an integer or a string in the form + ``port/protocol``, where the protocol is either ``tcp``, + ``udp``, or ``sctp``. + + The values of the dictionary are the corresponding ports to + open on the host, which can be either: + + - The port number, as an integer. For example, + ``{'2222/tcp': 3333}`` will expose port 2222 inside the + container as port 3333 on the host. + - ``None``, to assign a random host port. For example, + ``{'2222/tcp': None}``. + - A tuple of ``(address, port)`` if you want to specify the + host interface. For example, + ``{'1111/tcp': ('127.0.0.1', 1111)}``. + - A list of integers, if you want to bind multiple host ports + to a single container port. For example, + ``{'1111/tcp': [1234, 4567]}``. + + Incompatible with ``host`` network mode. + privileged (bool): Give extended privileges to this container. + publish_all_ports (bool): Publish all ports to the host. + read_only (bool): Mount the container's root filesystem as read + only. + remove (bool): Remove the container when it has finished running. + Default: ``False``. + restart_policy (dict): Restart the container when it exits. + Configured as a dictionary with keys: + + - ``Name`` One of ``on-failure``, or ``always``. + - ``MaximumRetryCount`` Number of times to restart the + container on failure. + + For example: + ``{"Name": "on-failure", "MaximumRetryCount": 5}`` + + runtime (str): Runtime to use with this container. + security_opt (:py:class:`list`): A list of string values to + customize labels for MLS systems, such as SELinux. + shm_size (str or int): Size of /dev/shm (e.g. ``1G``). + stdin_open (bool): Keep ``STDIN`` open even if not attached. + stdout (bool): Return logs from ``STDOUT`` when ``detach=False``. + Default: ``True``. + stderr (bool): Return logs from ``STDERR`` when ``detach=False``. + Default: ``False``. + stop_signal (str): The stop signal to use to stop the container + (e.g. ``SIGINT``). + storage_opt (dict): Storage driver options per container as a + key-value mapping. + stream (bool): If true and ``detach`` is false, return a log + generator instead of a string. Ignored if ``detach`` is true. + Default: ``False``. + sysctls (dict): Kernel parameters to set in the container. + tmpfs (dict): Temporary filesystems to mount, as a dictionary + mapping a path inside the container to options for that path. + + For example: + + .. code-block:: python + + { + '/mnt/vol2': '', + '/mnt/vol1': 'size=3G,uid=1000' + } + + tty (bool): Allocate a pseudo-TTY. + ulimits (:py:class:`list`): Ulimits to set inside the container, + as a list of :py:class:`docker.types.Ulimit` instances. + use_config_proxy (bool): If ``True``, and if the docker client + configuration file (``~/.docker/config.json`` by default) + contains a proxy configuration, the corresponding environment + variables will be set in the container being built. + user (str or int): Username or UID to run commands as inside the + container. + userns_mode (str): Sets the user namespace mode for the container + when user namespace remapping option is enabled. Supported + values are: ``host`` + uts_mode (str): Sets the UTS namespace mode for the container. + Supported values are: ``host`` + version (str): The version of the API to use. Set to ``auto`` to + automatically detect the server's version. Default: ``1.35`` + volume_driver (str): The name of a volume driver/plugin. + volumes (dict or list): A dictionary to configure volumes mounted + inside the container. The key is either the host path or a + volume name, and the value is a dictionary with the keys: + + - ``bind`` The path to mount the volume inside the container + - ``mode`` Either ``rw`` to mount the volume read/write, or + ``ro`` to mount it read-only. + + For example: + + .. code-block:: python + + {'/home/user1/': {'bind': '/mnt/vol2', 'mode': 'rw'}, + '/var/www': {'bind': '/mnt/vol1', 'mode': 'ro'}} + + Or a list of strings which each one of its elements specifies a + mount volume. + + For example: + + .. code-block:: python + + ['/home/user1/:/mnt/vol2','/var/www:/mnt/vol1'] + + volumes_from (:py:class:`list`): List of container names or IDs to + get volumes from. + working_dir (str): Path to the working directory. + + Returns: + The container logs, either ``STDOUT``, ``STDERR``, or both, + depending on the value of the ``stdout`` and ``stderr`` arguments. + + ``STDOUT`` and ``STDERR`` may be read only if either ``json-file`` + or ``journald`` logging driver used. Thus, if you are using none of + these drivers, a ``None`` object is returned instead. See the + `Engine API documentation + `_ + for full details. + + If ``detach`` is ``True``, a :py:class:`Container` object is + returned instead. + + Raises: + :py:class:`docker.errors.ContainerError` + If the container exits with a non-zero exit code and + ``detach`` is ``False``. + :py:class:`docker.errors.ImageNotFound` + If the specified image does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + def create(self, image, command=..., **kwargs): # -> Model: + """Create a container without starting it. Similar to ``docker create``. + + Takes the same arguments as :py:meth:`run`, except for ``stdout``, + ``stderr``, and ``remove``. + + Returns: + A :py:class:`Container` object. + + Raises: + :py:class:`docker.errors.ImageNotFound` + If the specified image does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def get(self, container_id: str) -> Container: + """Get a container by name or ID. + + Args: + container_id (str): Container name or ID. + + Returns: + A :py:class:`Container` object. + + Raises: + :py:class:`docker.errors.NotFound` + If the container does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def list( + self, all=..., before=..., filters=..., limit=..., since=..., sparse=..., ignore_removed=... + ): # -> list[Model] | list[Unknown]: + """List containers. Similar to the ``docker ps`` command. + + Args: + all (bool): Show all containers. Only running containers are shown + by default + since (str): Show only containers created since Id or Name, include + non-running ones + before (str): Show only container created before Id or Name, + include non-running ones + limit (int): Show `limit` last created containers, include + non-running ones + filters (dict): Filters to be processed on the image list. + Available filters: + + - `exited` (int): Only containers with specified exit code + - `status` (str): One of ``restarting``, ``running``, + ``paused``, ``exited`` + - `label` (str|list): format either ``"key"``, ``"key=value"`` + or a list of such. + - `id` (str): The id of the container. + - `name` (str): The name of the container. + - `ancestor` (str): Filter by container ancestor. Format of + ``[:tag]``, ````, or + ````. + - `before` (str): Only containers created before a particular + container. Give the container name or id. + - `since` (str): Only containers created after a particular + container. Give container name or id. + + A comprehensive list can be found in the documentation for + `docker ps + `_. + + sparse (bool): Do not inspect containers. Returns partial + information, but guaranteed not to block. Use + :py:meth:`Container.reload` on resulting objects to retrieve + all attributes. Default: ``False`` + ignore_removed (bool): Ignore failures due to missing containers + when attempting to inspect containers from the original list. + Set to ``True`` if race conditions are likely. Has no effect + if ``sparse=True``. Default: ``False`` + + Returns: + (list of :py:class:`Container`) + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def prune(self, filters=...): ... + +RUN_CREATE_KWARGS = ... +RUN_HOST_CONFIG_KWARGS = ... +ExecResult = ... diff --git a/typings/docker/models/images.pyi b/typings/docker/models/images.pyi new file mode 100644 index 00000000..3d636212 --- /dev/null +++ b/typings/docker/models/images.pyi @@ -0,0 +1,329 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Image(Model): + """An image on the server.""" + + def __repr__(self): ... + @property + def labels(self): # -> dict[Any, Any]: + """The labels of an image as dictionary.""" + ... + @property + def short_id(self): + """The ID of the image truncated to 12 characters, plus the ``sha256:`` + prefix. + """ + ... + @property + def tags(self): # -> list[Unknown]: + """The image's tags.""" + ... + def history(self): + """Show the history of an image. + + Returns: + (str): The history of the image. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def remove(self, force=..., noprune=...): + """Remove this image. + + Args: + force (bool): Force removal of the image + noprune (bool): Do not delete untagged parents + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def save(self, chunk_size=..., named=...): + """Get a tarball of an image. Similar to the ``docker save`` command. + + Args: + chunk_size (int): The generator will return up to that much data + per iteration, but may return less. If ``None``, data will be + streamed as it is received. Default: 2 MB + named (str or bool): If ``False`` (default), the tarball will not + retain repository and tag information for this image. If set + to ``True``, the first tag in the :py:attr:`~tags` list will + be used to identify the image. Alternatively, any element of + the :py:attr:`~tags` list can be used as an argument to use + that specific tag as the saved identifier. + + Returns: + (generator): A stream of raw archive data. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> image = cli.images.get("busybox:latest") + >>> f = open('/tmp/busybox-latest.tar', 'wb') + >>> for chunk in image.save(): + >>> f.write(chunk) + >>> f.close() + """ + ... + def tag(self, repository, tag=..., **kwargs): + """Tag this image into a repository. Similar to the ``docker tag`` + command. + + Args: + repository (str): The repository to set for the tag + tag (str): The tag name + force (bool): Force + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Returns: + (bool): ``True`` if successful + """ + ... + +class RegistryData(Model): + """Image metadata stored on the registry, including available platforms.""" + + def __init__(self, image_name, *args, **kwargs) -> None: ... + @property + def id(self): + """The ID of the object.""" + ... + @property + def short_id(self): + """The ID of the image truncated to 12 characters, plus the ``sha256:`` + prefix. + """ + ... + def pull(self, platform=...): + """Pull the image digest. + + Args: + platform (str): The platform to pull the image for. + Default: ``None`` + + Returns: + (:py:class:`Image`): A reference to the pulled image. + """ + ... + def has_platform(self, platform): # -> bool: + """Check whether the given platform identifier is available for this + digest. + + Args: + platform (str or dict): A string using the ``os[/arch[/variant]]`` + format, or a platform dictionary. + + Returns: + (bool): ``True`` if the platform is recognized as available, + ``False`` otherwise. + + Raises: + :py:class:`docker.errors.InvalidArgument` + If the platform argument is not a valid descriptor. + """ + ... + def reload(self): ... + +class ImageCollection(Collection): + model = Image + def build(self, **kwargs): # -> Model | tuple[Model, Iterator[Any]]: + """Build an image and return it. Similar to the ``docker build`` + command. Either ``path`` or ``fileobj`` must be set. + + If you already have a tar file for the Docker build context (including + a Dockerfile), pass a readable file-like object to ``fileobj`` + and also pass ``custom_context=True``. If the stream is also + compressed, set ``encoding`` to the correct value (e.g ``gzip``). + + If you want to get the raw output of the build, use the + :py:meth:`~docker.api.build.BuildApiMixin.build` method in the + low-level API. + + Args: + path (str): Path to the directory containing the Dockerfile + fileobj: A file object to use as the Dockerfile. (Or a file-like + object) + tag (str): A tag to add to the final image + quiet (bool): Whether to return the status + nocache (bool): Don't use the cache when set to ``True`` + rm (bool): Remove intermediate containers. The ``docker build`` + command now defaults to ``--rm=true``, but we have kept the old + default of `False` to preserve backward compatibility + timeout (int): HTTP timeout + custom_context (bool): Optional if using ``fileobj`` + encoding (str): The encoding for a stream. Set to ``gzip`` for + compressing + pull (bool): Downloads any updates to the FROM image in Dockerfiles + forcerm (bool): Always remove intermediate containers, even after + unsuccessful builds + dockerfile (str): path within the build context to the Dockerfile + buildargs (dict): A dictionary of build arguments + container_limits (dict): A dictionary of limits applied to each + container created by the build process. Valid keys: + + - memory (int): set memory limit for build + - memswap (int): Total memory (memory + swap), -1 to disable + swap + - cpushares (int): CPU shares (relative weight) + - cpusetcpus (str): CPUs in which to allow execution, e.g., + ``"0-3"``, ``"0,1"`` + shmsize (int): Size of `/dev/shm` in bytes. The size must be + greater than 0. If omitted the system uses 64MB + labels (dict): A dictionary of labels to set on the image + cache_from (list): A list of images used for build cache + resolution + target (str): Name of the build-stage to build in a multi-stage + Dockerfile + network_mode (str): networking mode for the run commands during + build + squash (bool): Squash the resulting images layers into a + single layer. + extra_hosts (dict): Extra hosts to add to /etc/hosts in building + containers, as a mapping of hostname to IP address. + platform (str): Platform in the format ``os[/arch[/variant]]``. + isolation (str): Isolation technology used during build. + Default: `None`. + use_config_proxy (bool): If ``True``, and if the docker client + configuration file (``~/.docker/config.json`` by default) + contains a proxy configuration, the corresponding environment + variables will be set in the container being built. + + Returns: + (tuple): The first item is the :py:class:`Image` object for the + image that was built. The second item is a generator of the + build logs as JSON-decoded objects. + + Raises: + :py:class:`docker.errors.BuildError` + If there is an error during the build. + :py:class:`docker.errors.APIError` + If the server returns any other error. + ``TypeError`` + If neither ``path`` nor ``fileobj`` is specified. + """ + ... + def get(self, name): # -> Model: + """Gets an image. + + Args: + name (str): The name of the image. + + Returns: + (:py:class:`Image`): The image. + + Raises: + :py:class:`docker.errors.ImageNotFound` + If the image does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def get_registry_data(self, name, auth_config=...): # -> RegistryData: + """Gets the registry data for an image. + + Args: + name (str): The name of the image. + auth_config (dict): Override the credentials that are found in the + config for this request. ``auth_config`` should contain the + ``username`` and ``password`` keys to be valid. + + Returns: + (:py:class:`RegistryData`): The data object. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def list(self, name=..., all=..., filters=...): # -> list[Model]: + """List images on the server. + + Args: + name (str): Only show images belonging to the repository ``name`` + all (bool): Show intermediate image layers. By default, these are + filtered out. + filters (dict): Filters to be processed on the image list. + Available filters: + - ``dangling`` (bool) + - `label` (str|list): format either ``"key"``, ``"key=value"`` + or a list of such. + + Returns: + (list of :py:class:`Image`): The images. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def load(self, data): # -> list[Model]: + """Load an image that was previously saved using + :py:meth:`~docker.models.images.Image.save` (or ``docker save``). + Similar to ``docker load``. + + Args: + data (binary): Image data to be loaded. + + Returns: + (list of :py:class:`Image`): The images. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def pull(self, repository, tag=..., all_tags=..., **kwargs): # -> Model | list[Model]: + """Pull an image of the given name and return it. Similar to the + ``docker pull`` command. + If ``tag`` is ``None`` or empty, it is set to ``latest``. + If ``all_tags`` is set, the ``tag`` parameter is ignored and all image + tags will be pulled. + + If you want to get the raw pull output, use the + :py:meth:`~docker.api.image.ImageApiMixin.pull` method in the + low-level API. + + Args: + repository (str): The repository to pull + tag (str): The tag to pull + auth_config (dict): Override the credentials that are found in the + config for this request. ``auth_config`` should contain the + ``username`` and ``password`` keys to be valid. + platform (str): Platform in the format ``os[/arch[/variant]]`` + all_tags (bool): Pull all image tags + + Returns: + (:py:class:`Image` or list): The image that has been pulled. + If ``all_tags`` is True, the method will return a list + of :py:class:`Image` objects belonging to this repository. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> # Pull the image tagged `latest` in the busybox repo + >>> image = client.images.pull('busybox') + + >>> # Pull all tags in the busybox repo + >>> images = client.images.pull('busybox', all_tags=True) + """ + ... + def push(self, repository, tag=..., **kwargs): ... + def remove(self, *args, **kwargs): ... + def search(self, *args, **kwargs): ... + def prune(self, filters=...): ... + def prune_builds(self, *args, **kwargs): ... + +def normalize_platform(platform, engine_info): ... diff --git a/typings/docker/models/networks.pyi b/typings/docker/models/networks.pyi new file mode 100644 index 00000000..846efc6e --- /dev/null +++ b/typings/docker/models/networks.pyi @@ -0,0 +1,175 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Network(Model): + """A Docker network.""" + + @property + def name(self): # -> None: + """The name of the network.""" + ... + @property + def containers(self): # -> list[Unknown]: + """The containers that are connected to the network, as a list of + :py:class:`~docker.models.containers.Container` objects. + """ + ... + def connect(self, container, *args, **kwargs): + """Connect a container to this network. + + Args: + container (str): Container to connect to this network, as either + an ID, name, or :py:class:`~docker.models.containers.Container` + object. + aliases (:py:class:`list`): A list of aliases for this endpoint. + Names in that list can be used within the network to reach the + container. Defaults to ``None``. + links (:py:class:`list`): A list of links for this endpoint. + Containers declared in this list will be linkedto this + container. Defaults to ``None``. + ipv4_address (str): The IP address of this container on the + network, using the IPv4 protocol. Defaults to ``None``. + ipv6_address (str): The IP address of this container on the + network, using the IPv6 protocol. Defaults to ``None``. + link_local_ips (:py:class:`list`): A list of link-local (IPv4/IPv6) + addresses. + driver_opt (dict): A dictionary of options to provide to the + network driver. Defaults to ``None``. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def disconnect(self, container, *args, **kwargs): + """Disconnect a container from this network. + + Args: + container (str): Container to disconnect from this network, as + either an ID, name, or + :py:class:`~docker.models.containers.Container` object. + force (bool): Force the container to disconnect from a network. + Default: ``False`` + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def remove(self): + """Remove this network. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + +class NetworkCollection(Collection): + """Networks on the Docker server.""" + + model = Network + def create(self, name, *args, **kwargs): # -> Model: + """Create a network. Similar to the ``docker network create``. + + Args: + name (str): Name of the network + driver (str): Name of the driver used to create the network + options (dict): Driver options as a key-value dictionary + ipam (IPAMConfig): Optional custom IP scheme for the network. + check_duplicate (bool): Request daemon to check for networks with + same name. Default: ``None``. + internal (bool): Restrict external access to the network. Default + ``False``. + labels (dict): Map of labels to set on the network. Default + ``None``. + enable_ipv6 (bool): Enable IPv6 on the network. Default ``False``. + attachable (bool): If enabled, and the network is in the global + scope, non-service containers on worker nodes will be able to + connect to the network. + scope (str): Specify the network's scope (``local``, ``global`` or + ``swarm``) + ingress (bool): If set, create an ingress network which provides + the routing-mesh in swarm mode. + + Returns: + (:py:class:`Network`): The network that was created. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + A network using the bridge driver: + + >>> client.networks.create("network1", driver="bridge") + + You can also create more advanced networks with custom IPAM + configurations. For example, setting the subnet to + ``192.168.52.0/24`` and gateway address to ``192.168.52.254``. + + .. code-block:: python + + >>> ipam_pool = docker.types.IPAMPool( + subnet='192.168.52.0/24', + gateway='192.168.52.254' + ) + >>> ipam_config = docker.types.IPAMConfig( + pool_configs=[ipam_pool] + ) + >>> client.networks.create( + "network1", + driver="bridge", + ipam=ipam_config + ) + + """ + ... + def get(self, network_id, *args, **kwargs): # -> Model: + """Get a network by its ID. + + Args: + network_id (str): The ID of the network. + verbose (bool): Retrieve the service details across the cluster in + swarm mode. + scope (str): Filter the network by scope (``swarm``, ``global`` + or ``local``). + + Returns: + (:py:class:`Network`) The network. + + Raises: + :py:class:`docker.errors.NotFound` + If the network does not exist. + + :py:class:`docker.errors.APIError` + If the server returns an error. + + """ + ... + def list(self, *args, **kwargs): # -> list[Model]: + """List networks. Similar to the ``docker network ls`` command. + + Args: + names (:py:class:`list`): List of names to filter by. + ids (:py:class:`list`): List of ids to filter by. + filters (dict): Filters to be processed on the network list. + Available filters: + - ``driver=[]`` Matches a network's driver. + - `label` (str|list): format either ``"key"``, ``"key=value"`` + or a list of such. + - ``type=["custom"|"builtin"]`` Filters networks by type. + greedy (bool): Fetch more details for each network individually. + You might want this to get the containers attached to them. + + Returns: + (list of :py:class:`Network`) The networks on the server. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def prune(self, filters=...): ... diff --git a/typings/docker/models/nodes.pyi b/typings/docker/models/nodes.pyi new file mode 100644 index 00000000..9909d294 --- /dev/null +++ b/typings/docker/models/nodes.pyi @@ -0,0 +1,95 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Node(Model): + """A node in a swarm.""" + + id_attribute = ... + @property + def version(self): + """The version number of the service. If this is not the same as the + server, the :py:meth:`update` function will not work and you will + need to call :py:meth:`reload` before calling it again. + """ + ... + def update(self, node_spec): + """Update the node's configuration. + + Args: + node_spec (dict): Configuration settings to update. Any values + not provided will be removed. Default: ``None`` + + Returns: + `True` if the request went through. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> node_spec = {'Availability': 'active', + 'Name': 'node-name', + 'Role': 'manager', + 'Labels': {'foo': 'bar'} + } + >>> node.update(node_spec) + + """ + ... + def remove(self, force=...): + """Remove this node from the swarm. + + Args: + force (bool): Force remove an active node. Default: `False` + + Returns: + `True` if the request was successful. + + Raises: + :py:class:`docker.errors.NotFound` + If the node doesn't exist in the swarm. + + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + +class NodeCollection(Collection): + """Nodes on the Docker server.""" + + model = Node + def get(self, node_id): # -> Model: + """Get a node. + + Args: + node_id (string): ID of the node to be inspected. + + Returns: + A :py:class:`Node` object. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def list(self, *args, **kwargs): # -> list[Model]: + """List swarm nodes. + + Args: + filters (dict): Filters to process on the nodes list. Valid + filters: ``id``, ``name``, ``membership`` and ``role``. + Default: ``None`` + + Returns: + A list of :py:class:`Node` objects. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> client.nodes.list(filters={'role': 'manager'}) + """ + ... diff --git a/typings/docker/models/plugins.pyi b/typings/docker/models/plugins.pyi new file mode 100644 index 00000000..a97887d4 --- /dev/null +++ b/typings/docker/models/plugins.pyi @@ -0,0 +1,152 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Plugin(Model): + """A plugin on the server.""" + + def __repr__(self): ... + @property + def name(self): # -> None: + """The plugin's name.""" + ... + @property + def enabled(self): # -> None: + """Whether the plugin is enabled.""" + ... + @property + def settings(self): # -> None: + """A dictionary representing the plugin's configuration.""" + ... + def configure(self, options): # -> None: + """Update the plugin's settings. + + Args: + options (dict): A key-value mapping of options. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def disable(self, force=...): # -> None: + """Disable the plugin. + + Args: + force (bool): Force disable. Default: False + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def enable(self, timeout=...): # -> None: + """Enable the plugin. + + Args: + timeout (int): Timeout in seconds. Default: 0 + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def push(self): + """Push the plugin to a remote registry. + + Returns: + A dict iterator streaming the status of the upload. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def remove(self, force=...): + """Remove the plugin from the server. + + Args: + force (bool): Remove even if the plugin is enabled. + Default: False + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def upgrade(self, remote=...): # -> Generator[Unknown, Unknown, None]: + """Upgrade the plugin. + + Args: + remote (string): Remote reference to upgrade to. The + ``:latest`` tag is optional and is the default if omitted. + Default: this plugin's name. + + Returns: + A generator streaming the decoded API logs + """ + ... + +class PluginCollection(Collection): + model = Plugin + def create(self, name, plugin_data_dir, gzip=...): # -> Model: + """Create a new plugin. + + Args: + name (string): The name of the plugin. The ``:latest`` tag is + optional, and is the default if omitted. + plugin_data_dir (string): Path to the plugin data directory. + Plugin data directory must contain the ``config.json`` + manifest file and the ``rootfs`` directory. + gzip (bool): Compress the context using gzip. Default: False + + Returns: + (:py:class:`Plugin`): The newly created plugin. + """ + ... + def get(self, name): # -> Model: + """Gets a plugin. + + Args: + name (str): The name of the plugin. + + Returns: + (:py:class:`Plugin`): The plugin. + + Raises: + :py:class:`docker.errors.NotFound` If the plugin does not + exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def install(self, remote_name, local_name=...): # -> Model: + """Pull and install a plugin. + + Args: + remote_name (string): Remote reference for the plugin to + install. The ``:latest`` tag is optional, and is the + default if omitted. + local_name (string): Local name for the pulled plugin. + The ``:latest`` tag is optional, and is the default if + omitted. Optional. + + Returns: + (:py:class:`Plugin`): The installed plugin + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def list(self): # -> list[Model]: + """List plugins installed on the server. + + Returns: + (list of :py:class:`Plugin`): The plugins. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/models/resource.pyi b/typings/docker/models/resource.pyi new file mode 100644 index 00000000..56d3c533 --- /dev/null +++ b/typings/docker/models/resource.pyi @@ -0,0 +1,38 @@ +"""This type stub file was generated by pyright.""" + +class Model: + """A base class for representing a single object on the server.""" + + id_attribute = ... + def __init__(self, attrs=..., client=..., collection=...) -> None: ... + def __repr__(self): ... + def __eq__(self, other) -> bool: ... + def __hash__(self) -> int: ... + @property + def id(self): # -> None: + """The ID of the object.""" + ... + @property + def short_id(self): + """The ID of the object, truncated to 12 characters.""" + ... + def reload(self): # -> None: + """Load this object from the server again and update ``attrs`` with the + new data. + """ + ... + +class Collection: + """A base class for representing all objects of a particular type on the + server. + """ + + model = ... + def __init__(self, client=...) -> None: ... + def __call__(self, *args, **kwargs): ... + def list(self): ... + def get(self, key): ... + def create(self, attrs=...): ... + def prepare_model(self, attrs): # -> Model: + """Create a model from a set of attributes.""" + ... diff --git a/typings/docker/models/secrets.pyi b/typings/docker/models/secrets.pyi new file mode 100644 index 00000000..82241998 --- /dev/null +++ b/typings/docker/models/secrets.pyi @@ -0,0 +1,56 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Secret(Model): + """A secret.""" + + id_attribute = ... + def __repr__(self): ... + @property + def name(self): ... + def remove(self): + """Remove this secret. + + Raises: + :py:class:`docker.errors.APIError` + If secret failed to remove. + """ + ... + +class SecretCollection(Collection): + """Secrets on the Docker server.""" + + model = Secret + def create(self, **kwargs): ... + def get(self, secret_id): # -> Model: + """Get a secret. + + Args: + secret_id (str): Secret ID. + + Returns: + (:py:class:`Secret`): The secret. + + Raises: + :py:class:`docker.errors.NotFound` + If the secret does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def list(self, **kwargs): # -> list[Model]: + """List secrets. Similar to the ``docker secret ls`` command. + + Args: + filters (dict): Server-side list filtering options. + + Returns: + (list of :py:class:`Secret`): The secrets. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... diff --git a/typings/docker/models/services.pyi b/typings/docker/models/services.pyi new file mode 100644 index 00000000..522bb734 --- /dev/null +++ b/typings/docker/models/services.pyi @@ -0,0 +1,227 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Service(Model): + """A service.""" + + id_attribute = ... + @property + def name(self): + """The service's name.""" + ... + @property + def version(self): + """The version number of the service. If this is not the same as the + server, the :py:meth:`update` function will not work and you will + need to call :py:meth:`reload` before calling it again. + """ + ... + def remove(self): + """Stop and remove the service. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def tasks(self, filters=...): + """List the tasks in this service. + + Args: + filters (dict): A map of filters to process on the tasks list. + Valid filters: ``id``, ``name``, ``node``, + ``label``, and ``desired-state``. + + Returns: + :py:class:`list`: List of task dictionaries. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def update(self, **kwargs): + """Update a service's configuration. Similar to the ``docker service + update`` command. + + Takes the same parameters as :py:meth:`~ServiceCollection.create`. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def logs(self, **kwargs): + """Get log stream for the service. + Note: This method works only for services with the ``json-file`` + or ``journald`` logging drivers. + + Args: + details (bool): Show extra details provided to logs. + Default: ``False`` + follow (bool): Keep connection open to read logs as they are + sent by the Engine. Default: ``False`` + stdout (bool): Return logs from ``stdout``. Default: ``False`` + stderr (bool): Return logs from ``stderr``. Default: ``False`` + since (int): UNIX timestamp for the logs staring point. + Default: 0 + timestamps (bool): Add timestamps to every log line. + tail (string or int): Number of log lines to be returned, + counting from the current end of the logs. Specify an + integer or ``'all'`` to output all log lines. + Default: ``all`` + + Returns: + generator: Logs for the service. + """ + ... + def scale(self, replicas): + """Scale service container. + + Args: + replicas (int): The number of containers that should be running. + + Returns: + bool: ``True`` if successful. + """ + ... + def force_update(self): + """Force update the service even if no changes require it. + + Returns: + bool: ``True`` if successful. + """ + ... + +class ServiceCollection(Collection): + """Services on the Docker server.""" + + model = Service + def create(self, image, command=..., **kwargs): # -> Model: + """Create a service. Similar to the ``docker service create`` command. + + Args: + image (str): The image name to use for the containers. + command (list of str or str): Command to run. + args (list of str): Arguments to the command. + constraints (list of str): :py:class:`~docker.types.Placement` + constraints. + preferences (list of tuple): :py:class:`~docker.types.Placement` + preferences. + maxreplicas (int): :py:class:`~docker.types.Placement` maxreplicas + or (int) representing maximum number of replicas per node. + platforms (list of tuple): A list of platform constraints + expressed as ``(arch, os)`` tuples. + container_labels (dict): Labels to apply to the container. + endpoint_spec (EndpointSpec): Properties that can be configured to + access and load balance a service. Default: ``None``. + env (list of str): Environment variables, in the form + ``KEY=val``. + hostname (string): Hostname to set on the container. + init (boolean): Run an init inside the container that forwards + signals and reaps processes + isolation (string): Isolation technology used by the service's + containers. Only used for Windows containers. + labels (dict): Labels to apply to the service. + log_driver (str): Log driver to use for containers. + log_driver_options (dict): Log driver options. + mode (ServiceMode): Scheduling mode for the service. + Default:``None`` + mounts (list of str): Mounts for the containers, in the form + ``source:target:options``, where options is either + ``ro`` or ``rw``. + name (str): Name to give to the service. + networks (:py:class:`list`): List of network names or IDs or + :py:class:`~docker.types.NetworkAttachmentConfig` to attach the + service to. Default: ``None``. + resources (Resources): Resource limits and reservations. + restart_policy (RestartPolicy): Restart policy for containers. + secrets (list of :py:class:`~docker.types.SecretReference`): List + of secrets accessible to containers for this service. + stop_grace_period (int): Amount of time to wait for + containers to terminate before forcefully killing them. + update_config (UpdateConfig): Specification for the update strategy + of the service. Default: ``None`` + rollback_config (RollbackConfig): Specification for the rollback + strategy of the service. Default: ``None`` + user (str): User to run commands as. + workdir (str): Working directory for commands to run. + tty (boolean): Whether a pseudo-TTY should be allocated. + groups (:py:class:`list`): A list of additional groups that the + container process will run as. + open_stdin (boolean): Open ``stdin`` + read_only (boolean): Mount the container's root filesystem as read + only. + stop_signal (string): Set signal to stop the service's containers + healthcheck (Healthcheck): Healthcheck + configuration for this service. + hosts (:py:class:`dict`): A set of host to IP mappings to add to + the container's `hosts` file. + dns_config (DNSConfig): Specification for DNS + related configurations in resolver configuration file. + configs (:py:class:`list`): List of + :py:class:`~docker.types.ConfigReference` that will be exposed + to the service. + privileges (Privileges): Security options for the service's + containers. + cap_add (:py:class:`list`): A list of kernel capabilities to add to + the default set for the container. + cap_drop (:py:class:`list`): A list of kernel capabilities to drop + from the default set for the container. + sysctls (:py:class:`dict`): A dict of sysctl values to add to the + container + + Returns: + :py:class:`Service`: The created service. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def get(self, service_id, insert_defaults=...): # -> Model: + """Get a service. + + Args: + service_id (str): The ID of the service. + insert_defaults (boolean): If true, default values will be merged + into the output. + + Returns: + :py:class:`Service`: The service. + + Raises: + :py:class:`docker.errors.NotFound` + If the service does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + :py:class:`docker.errors.InvalidVersion` + If one of the arguments is not supported with the current + API version. + """ + ... + def list(self, **kwargs): # -> list[Model]: + """List services. + + Args: + filters (dict): Filters to process on the nodes list. Valid + filters: ``id``, ``name`` , ``label`` and ``mode``. + Default: ``None``. + status (bool): Include the service task count of running and + desired tasks. Default: ``None``. + + Returns: + list of :py:class:`Service`: The services. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + +CONTAINER_SPEC_KWARGS = ... +TASK_TEMPLATE_KWARGS = ... +CREATE_SERVICE_KWARGS = ... +PLACEMENT_KWARGS = ... diff --git a/typings/docker/models/swarm.pyi b/typings/docker/models/swarm.pyi new file mode 100644 index 00000000..48f4139f --- /dev/null +++ b/typings/docker/models/swarm.pyi @@ -0,0 +1,143 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Model + +class Swarm(Model): + """The server's Swarm state. This a singleton that must be reloaded to get + the current state of the Swarm. + """ + + id_attribute = ... + def __init__(self, *args, **kwargs) -> None: ... + @property + def version(self): + """The version number of the swarm. If this is not the same as the + server, the :py:meth:`update` function will not work and you will + need to call :py:meth:`reload` before calling it again. + """ + ... + def get_unlock_key(self): ... + def init( + self, + advertise_addr=..., + listen_addr=..., + force_new_cluster=..., + default_addr_pool=..., + subnet_size=..., + data_path_addr=..., + data_path_port=..., + **kwargs, + ): + """Initialize a new swarm on this Engine. + + Args: + advertise_addr (str): Externally reachable address advertised to + other nodes. This can either be an address/port combination in + the form ``192.168.1.1:4567``, or an interface followed by a + port number, like ``eth0:4567``. If the port number is omitted, + the port number from the listen address is used. + + If not specified, it will be automatically detected when + possible. + listen_addr (str): Listen address used for inter-manager + communication, as well as determining the networking interface + used for the VXLAN Tunnel Endpoint (VTEP). This can either be + an address/port combination in the form ``192.168.1.1:4567``, + or an interface followed by a port number, like ``eth0:4567``. + If the port number is omitted, the default swarm listening port + is used. Default: ``0.0.0.0:2377`` + force_new_cluster (bool): Force creating a new Swarm, even if + already part of one. Default: False + default_addr_pool (list of str): Default Address Pool specifies + default subnet pools for global scope networks. Each pool + should be specified as a CIDR block, like '10.0.0.0/8'. + Default: None + subnet_size (int): SubnetSize specifies the subnet size of the + networks created from the default subnet pool. Default: None + data_path_addr (string): Address or interface to use for data path + traffic. For example, 192.168.1.1, or an interface, like eth0. + data_path_port (int): Port number to use for data path traffic. + Acceptable port range is 1024 to 49151. If set to ``None`` or + 0, the default port 4789 will be used. Default: None + task_history_retention_limit (int): Maximum number of tasks + history stored. + snapshot_interval (int): Number of logs entries between snapshot. + keep_old_snapshots (int): Number of snapshots to keep beyond the + current snapshot. + log_entries_for_slow_followers (int): Number of log entries to + keep around to sync up slow followers after a snapshot is + created. + heartbeat_tick (int): Amount of ticks (in seconds) between each + heartbeat. + election_tick (int): Amount of ticks (in seconds) needed without a + leader to trigger a new election. + dispatcher_heartbeat_period (int): The delay for an agent to send + a heartbeat to the dispatcher. + node_cert_expiry (int): Automatic expiry for nodes certificates. + external_ca (dict): Configuration for forwarding signing requests + to an external certificate authority. Use + ``docker.types.SwarmExternalCA``. + name (string): Swarm's name + labels (dict): User-defined key/value metadata. + signing_ca_cert (str): The desired signing CA certificate for all + swarm node TLS leaf certificates, in PEM format. + signing_ca_key (str): The desired signing CA key for all swarm + node TLS leaf certificates, in PEM format. + ca_force_rotate (int): An integer whose purpose is to force swarm + to generate a new signing CA certificate and key, if none have + been specified. + autolock_managers (boolean): If set, generate a key and use it to + lock data stored on the managers. + log_driver (DriverConfig): The default log driver to use for tasks + created in the orchestrator. + + Returns: + (str): The ID of the created node. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> client.swarm.init( + advertise_addr='eth0', listen_addr='0.0.0.0:5000', + force_new_cluster=False, default_addr_pool=['10.20.0.0/16], + subnet_size=24, snapshot_interval=5000, + log_entries_for_slow_followers=1200 + ) + + """ + ... + def join(self, *args, **kwargs): ... + def leave(self, *args, **kwargs): ... + def reload(self): # -> None: + """Inspect the swarm on the server and store the response in + :py:attr:`attrs`. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def unlock(self, key): ... + def update(self, rotate_worker_token=..., rotate_manager_token=..., rotate_manager_unlock_key=..., **kwargs): + """Update the swarm's configuration. + + It takes the same arguments as :py:meth:`init`, except + ``advertise_addr``, ``listen_addr``, and ``force_new_cluster``. In + addition, it takes these arguments: + + Args: + rotate_worker_token (bool): Rotate the worker join token. Default: + ``False``. + rotate_manager_token (bool): Rotate the manager join token. + Default: ``False``. + rotate_manager_unlock_key (bool): Rotate the manager unlock key. + Default: ``False``. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + """ + ... diff --git a/typings/docker/models/volumes.pyi b/typings/docker/models/volumes.pyi new file mode 100644 index 00000000..a8780adb --- /dev/null +++ b/typings/docker/models/volumes.pyi @@ -0,0 +1,85 @@ +"""This type stub file was generated by pyright.""" + +from .resource import Collection +from .resource import Model + +class Volume(Model): + """A volume.""" + + id_attribute = ... + @property + def name(self): + """The name of the volume.""" + ... + def remove(self, force=...): + """Remove this volume. + + Args: + force (bool): Force removal of volumes that were already removed + out of band by the volume driver plugin. + + Raises: + :py:class:`docker.errors.APIError` + If volume failed to remove. + """ + ... + +class VolumeCollection(Collection): + """Volumes on the Docker server.""" + + model = Volume + def create(self, name=..., **kwargs): # -> Model: + """Create a volume. + + Args: + name (str): Name of the volume. If not specified, the engine + generates a name. + driver (str): Name of the driver used to create the volume + driver_opts (dict): Driver options as a key-value dictionary + labels (dict): Labels to set on the volume + + Returns: + (:py:class:`Volume`): The volume created. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + + Example: + >>> volume = client.volumes.create(name='foobar', driver='local', + driver_opts={'foo': 'bar', 'baz': 'false'}, + labels={"key": "value"}) + + """ + ... + def get(self, volume_id): # -> Model: + """Get a volume. + + Args: + volume_id (str): Volume name. + + Returns: + (:py:class:`Volume`): The volume. + + Raises: + :py:class:`docker.errors.NotFound` + If the volume does not exist. + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def list(self, **kwargs): # -> list[Unknown] | list[Model]: + """List volumes. Similar to the ``docker volume ls`` command. + + Args: + filters (dict): Server-side list filtering options. + + Returns: + (list of :py:class:`Volume`): The volumes. + + Raises: + :py:class:`docker.errors.APIError` + If the server returns an error. + """ + ... + def prune(self, filters=...): ... diff --git a/typings/docker/transport/__init__.pyi b/typings/docker/transport/__init__.pyi new file mode 100644 index 00000000..6c6d7b34 --- /dev/null +++ b/typings/docker/transport/__init__.pyi @@ -0,0 +1,5 @@ +from .npipeconn import NpipeHTTPAdapter as NpipeHTTPAdapter +from .npipesocket import NpipeSocket as NpipeSocket +from .sshconn import SSHHTTPAdapter as SSHHTTPAdapter +from .ssladapter import SSLHTTPAdapter as SSLHTTPAdapter +from .unixconn import UnixHTTPAdapter as UnixHTTPAdapter diff --git a/typings/docker/transport/basehttpadapter.pyi b/typings/docker/transport/basehttpadapter.pyi new file mode 100644 index 00000000..1200bb42 --- /dev/null +++ b/typings/docker/transport/basehttpadapter.pyi @@ -0,0 +1,6 @@ +"""This type stub file was generated by pyright.""" + +import requests.adapters + +class BaseHTTPAdapter(requests.adapters.HTTPAdapter): + def close(self): ... diff --git a/typings/docker/transport/npipeconn.pyi b/typings/docker/transport/npipeconn.pyi new file mode 100644 index 00000000..3df449a3 --- /dev/null +++ b/typings/docker/transport/npipeconn.pyi @@ -0,0 +1,20 @@ +"""This type stub file was generated by pyright.""" + +import urllib3 +import urllib3.connection +from docker.transport.basehttpadapter import BaseHTTPAdapter + +RecentlyUsedContainer = ... + +class NpipeHTTPConnection(urllib3.connection.HTTPConnection): + def __init__(self, npipe_path, timeout=...) -> None: ... + def connect(self): ... + +class NpipeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): + def __init__(self, npipe_path, timeout=..., maxsize=...) -> None: ... + +class NpipeHTTPAdapter(BaseHTTPAdapter): + __attrs__ = ... + def __init__(self, base_url, timeout=..., pool_connections=..., max_pool_size=...) -> None: ... + def get_connection(self, url, proxies=...): ... + def request_url(self, request, proxies): ... diff --git a/typings/docker/transport/npipesocket.pyi b/typings/docker/transport/npipesocket.pyi new file mode 100644 index 00000000..660e73fa --- /dev/null +++ b/typings/docker/transport/npipesocket.pyi @@ -0,0 +1,66 @@ +"""This type stub file was generated by pyright.""" + +import io + +cERROR_PIPE_BUSY = ... +cSECURITY_SQOS_PRESENT = ... +cSECURITY_ANONYMOUS = ... +MAXIMUM_RETRY_COUNT = ... + +def check_closed(f): ... + +class NpipeSocket: + """Partial implementation of the socket API over windows named pipes. + This implementation is only designed to be used as a client socket, + and server-specific methods (bind, listen, accept...) are not + implemented. + """ + + def __init__(self, handle=...) -> None: ... + def accept(self): ... + def bind(self, address): ... + def close(self): ... + @check_closed + def connect(self, address, retry_count=...): ... + @check_closed + def connect_ex(self, address): ... + @check_closed + def detach(self): ... + @check_closed + def dup(self): ... + def getpeername(self): ... + def getsockname(self): ... + def getsockopt(self, level, optname, buflen=...): ... + def ioctl(self, control, option): ... + def listen(self, backlog): ... + def makefile(self, mode=..., bufsize=...): ... + @check_closed + def recv(self, bufsize, flags=...): ... + @check_closed + def recvfrom(self, bufsize, flags=...): ... + @check_closed + def recvfrom_into(self, buf, nbytes=..., flags=...): ... + @check_closed + def recv_into(self, buf, nbytes=...): ... + @check_closed + def send(self, string, flags=...): ... + @check_closed + def sendall(self, string, flags=...): ... + @check_closed + def sendto(self, string, address): ... + def setblocking(self, flag): ... + def settimeout(self, value): ... + def gettimeout(self): ... + def setsockopt(self, level, optname, value): ... + @check_closed + def shutdown(self, how): ... + +class NpipeFileIOBase(io.RawIOBase): + def __init__(self, npipe_socket) -> None: ... + def close(self): ... + def fileno(self): ... + def isatty(self): ... + def readable(self): ... + def readinto(self, buf): ... + def seekable(self): ... + def writable(self): ... diff --git a/typings/docker/transport/sshconn.pyi b/typings/docker/transport/sshconn.pyi new file mode 100644 index 00000000..3749094f --- /dev/null +++ b/typings/docker/transport/sshconn.pyi @@ -0,0 +1,32 @@ +"""This type stub file was generated by pyright.""" + +import socket + +import urllib3 +import urllib3.connection +from docker.transport.basehttpadapter import BaseHTTPAdapter + +RecentlyUsedContainer = ... + +class SSHSocket(socket.socket): + def __init__(self, host) -> None: ... + def connect(self, **kwargs): ... + def sendall(self, data): ... + def send(self, data): ... + def recv(self, n): ... + def makefile(self, mode): ... + def close(self): ... + +class SSHConnection(urllib3.connection.HTTPConnection): + def __init__(self, ssh_transport=..., timeout=..., host=...) -> None: ... + def connect(self): ... + +class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): + scheme = ... + def __init__(self, ssh_client=..., timeout=..., maxsize=..., host=...) -> None: ... + +class SSHHTTPAdapter(BaseHTTPAdapter): + __attrs__ = ... + def __init__(self, base_url, timeout=..., pool_connections=..., max_pool_size=..., shell_out=...) -> None: ... + def get_connection(self, url, proxies=...): ... + def close(self): ... diff --git a/typings/docker/transport/ssladapter.pyi b/typings/docker/transport/ssladapter.pyi new file mode 100644 index 00000000..66d5cbc7 --- /dev/null +++ b/typings/docker/transport/ssladapter.pyi @@ -0,0 +1,26 @@ +"""This type stub file was generated by pyright.""" + +import urllib3 +from docker.transport.basehttpadapter import BaseHTTPAdapter + +""" Resolves OpenSSL issues in some servers: + https://lukasa.co.uk/2013/01/Choosing_SSL_Version_In_Requests/ + https://github.com/kennethreitz/requests/pull/799 +""" +PoolManager = urllib3.poolmanager.PoolManager + +class SSLHTTPAdapter(BaseHTTPAdapter): + """An HTTPS Transport Adapter that uses an arbitrary SSL version.""" + + __attrs__ = ... + def __init__(self, ssl_version=..., assert_hostname=..., assert_fingerprint=..., **kwargs) -> None: ... + def init_poolmanager(self, connections, maxsize, block=...): ... + def get_connection(self, *args, **kwargs): + """Ensure assert_hostname is set correctly on our pool. + + We already take care of a normal poolmanager via init_poolmanager + + But we still need to take care of when there is a proxy poolmanager + """ + ... + def can_override_ssl_version(self): ... diff --git a/typings/docker/transport/unixconn.pyi b/typings/docker/transport/unixconn.pyi new file mode 100644 index 00000000..d0e7bf72 --- /dev/null +++ b/typings/docker/transport/unixconn.pyi @@ -0,0 +1,20 @@ +"""This type stub file was generated by pyright.""" + +import urllib3 +import urllib3.connection +from docker.transport.basehttpadapter import BaseHTTPAdapter + +RecentlyUsedContainer = ... + +class UnixHTTPConnection(urllib3.connection.HTTPConnection): + def __init__(self, base_url, unix_socket, timeout=...) -> None: ... + def connect(self): ... + +class UnixHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): + def __init__(self, base_url, socket_path, timeout=..., maxsize=...) -> None: ... + +class UnixHTTPAdapter(BaseHTTPAdapter): + __attrs__ = ... + def __init__(self, socket_url, timeout=..., pool_connections=..., max_pool_size=...) -> None: ... + def get_connection(self, url, proxies=...): ... + def request_url(self, request, proxies): ... diff --git a/typings/docker/types/__init__.pyi b/typings/docker/types/__init__.pyi new file mode 100644 index 00000000..b5f7d939 --- /dev/null +++ b/typings/docker/types/__init__.pyi @@ -0,0 +1,30 @@ +from .containers import ContainerConfig as ContainerConfig +from .containers import DeviceRequest as DeviceRequest +from .containers import HostConfig as HostConfig +from .containers import LogConfig as LogConfig +from .containers import Ulimit as Ulimit +from .daemon import CancellableStream as CancellableStream +from .healthcheck import Healthcheck as Healthcheck +from .networks import EndpointConfig as EndpointConfig +from .networks import IPAMConfig as IPAMConfig +from .networks import IPAMPool as IPAMPool +from .networks import NetworkingConfig as NetworkingConfig +from .services import ConfigReference as ConfigReference +from .services import ContainerSpec as ContainerSpec +from .services import DNSConfig as DNSConfig +from .services import DriverConfig as DriverConfig +from .services import EndpointSpec as EndpointSpec +from .services import Mount as Mount +from .services import NetworkAttachmentConfig as NetworkAttachmentConfig +from .services import Placement as Placement +from .services import PlacementPreference as PlacementPreference +from .services import Privileges as Privileges +from .services import Resources as Resources +from .services import RestartPolicy as RestartPolicy +from .services import RollbackConfig as RollbackConfig +from .services import SecretReference as SecretReference +from .services import ServiceMode as ServiceMode +from .services import TaskTemplate as TaskTemplate +from .services import UpdateConfig as UpdateConfig +from .swarm import SwarmExternalCA as SwarmExternalCA +from .swarm import SwarmSpec as SwarmSpec diff --git a/typings/docker/types/base.pyi b/typings/docker/types/base.pyi new file mode 100644 index 00000000..8a9f8907 --- /dev/null +++ b/typings/docker/types/base.pyi @@ -0,0 +1,4 @@ +"""This type stub file was generated by pyright.""" + +class DictType(dict): + def __init__(self, init) -> None: ... diff --git a/typings/docker/types/containers.pyi b/typings/docker/types/containers.pyi new file mode 100644 index 00000000..8da7346f --- /dev/null +++ b/typings/docker/types/containers.pyi @@ -0,0 +1,238 @@ +from typing import Any + +from .base import DictType + +class LogConfigTypesEnum: + _values = ... + +class LogConfig(DictType): + """Configure logging for a container, when provided as an argument to + :py:meth:`~docker.api.container.ContainerApiMixin.create_host_config`. + You may refer to the + `official logging driver documentation `_ + for more information. + + Args: + type (str): Indicate which log driver to use. A set of valid drivers + is provided as part of the :py:attr:`LogConfig.types` + enum. Other values may be accepted depending on the engine version + and available logging plugins. + config (dict): A driver-dependent configuration dictionary. Please + refer to the driver's documentation for a list of valid config + keys. + + Example: + >>> from docker.types import LogConfig + >>> lc = LogConfig(type=LogConfig.types.JSON, config={ + ... 'max-size': '1g', + ... 'labels': 'production_status,geo' + ... }) + >>> hc = client.create_host_config(log_config=lc) + >>> container = client.create_container('busybox', 'true', + ... host_config=hc) + >>> client.inspect_container(container)['HostConfig']['LogConfig'] + {'Type': 'json-file', 'Config': {'labels': 'production_status,geo', 'max-size': '1g'}} + """ + + types = LogConfigTypesEnum + def __init__(self, **kwargs: Any) -> None: ... + @property + def type(self): ... + @type.setter + def type(self, value: Any): ... + @property + def config(self): ... + def set_config_value(self, key: str, value: Any) -> None: + """Set a the value for ``key`` to ``value`` inside the ``config`` + dict. + """ + ... + def unset_config(self, key: str) -> None: + """Remove the ``key`` property from the ``config`` dict.""" + ... + +class Ulimit(DictType): + """Create a ulimit declaration to be used with + :py:meth:`~docker.api.container.ContainerApiMixin.create_host_config`. + + Args: + name (str): Which ulimit will this apply to. The valid names can be + found in '/etc/security/limits.conf' on a gnu/linux system. + soft (int): The soft limit for this ulimit. Optional. + hard (int): The hard limit for this ulimit. Optional. + + Example: + >>> nproc_limit = docker.types.Ulimit(name='nproc', soft=1024) + >>> hc = client.create_host_config(ulimits=[nproc_limit]) + >>> container = client.create_container( + 'busybox', 'true', host_config=hc + ) + >>> client.inspect_container(container)['HostConfig']['Ulimits'] + [{'Name': 'nproc', 'Hard': 0, 'Soft': 1024}] + + """ + + def __init__(self, **kwargs) -> None: ... + @property + def name(self): ... + @name.setter + def name(self, value): ... + @property + def soft(self): ... + @soft.setter + def soft(self, value): ... + @property + def hard(self): ... + @hard.setter + def hard(self, value): ... + +class DeviceRequest(DictType): + """Create a device request to be used with + :py:meth:`~docker.api.container.ContainerApiMixin.create_host_config`. + + Args: + driver (str): Which driver to use for this device. Optional. + count (int): Number or devices to request. Optional. + Set to -1 to request all available devices. + device_ids (list): List of strings for device IDs. Optional. + Set either ``count`` or ``device_ids``. + capabilities (list): List of lists of strings to request + capabilities. Optional. The global list acts like an OR, + and the sub-lists are AND. The driver will try to satisfy + one of the sub-lists. + Available capabilities for the ``nvidia`` driver can be found + `here `_. + options (dict): Driver-specific options. Optional. + """ + + def __init__( + self, + count: int | None = ..., + driver: str | None = ..., + device_ids: list[str] | None = ..., + capabilities: list[list[str]] | None = ..., + options: dict[str, str] | None = ..., + ) -> None: ... + @property + def driver(self) -> str: ... + @driver.setter + def driver(self, value: str) -> None: ... + @property + def count(self) -> int: ... + @count.setter + def count(self, value: int) -> None: ... + @property + def device_ids(self): ... + @device_ids.setter + def device_ids(self, value): ... + @property + def capabilities(self): ... + @capabilities.setter + def capabilities(self, value): ... + @property + def options(self): ... + @options.setter + def options(self, value): ... + +class HostConfig(dict): + def __init__( + self, + version, + binds=..., + port_bindings=..., + lxc_conf=..., + publish_all_ports=..., + links=..., + privileged=..., + dns=..., + dns_search=..., + volumes_from=..., + network_mode=..., + restart_policy=..., + cap_add=..., + cap_drop=..., + devices=..., + extra_hosts=..., + read_only=..., + pid_mode=..., + ipc_mode=..., + security_opt=..., + ulimits=..., + log_config=..., + mem_limit=..., + memswap_limit=..., + mem_reservation=..., + kernel_memory=..., + mem_swappiness=..., + cgroup_parent=..., + group_add=..., + cpu_quota=..., + cpu_period=..., + blkio_weight=..., + blkio_weight_device=..., + device_read_bps=..., + device_write_bps=..., + device_read_iops=..., + device_write_iops=..., + oom_kill_disable=..., + shm_size=..., + sysctls=..., + tmpfs=..., + oom_score_adj=..., + dns_opt=..., + cpu_shares=..., + cpuset_cpus=..., + userns_mode=..., + uts_mode=..., + pids_limit=..., + isolation=..., + auto_remove=..., + storage_opt=..., + init=..., + init_path=..., + volume_driver=..., + cpu_count=..., + cpu_percent=..., + nano_cpus=..., + cpuset_mems=..., + runtime=..., + mounts=..., + cpu_rt_period=..., + cpu_rt_runtime=..., + device_cgroup_rules=..., + device_requests=..., + cgroupns=..., + ) -> None: ... + +def host_config_type_error(param, param_value, expected): ... +def host_config_version_error(param, version, less_than=...): ... +def host_config_value_error(param, param_value): ... +def host_config_incompatible_error(param, param_value, incompatible_param): ... + +class ContainerConfig(dict): + def __init__( + self, + version, + image, + command, + hostname=..., + user=..., + detach=..., + stdin_open=..., + tty=..., + ports=..., + environment=..., + volumes=..., + network_disabled=..., + entrypoint=..., + working_dir=..., + domainname=..., + host_config=..., + mac_address=..., + labels=..., + stop_signal=..., + networking_config=..., + healthcheck=..., + stop_timeout=..., + runtime=..., + ) -> None: ... diff --git a/typings/docker/types/daemon.pyi b/typings/docker/types/daemon.pyi new file mode 100644 index 00000000..5bcb9da6 --- /dev/null +++ b/typings/docker/types/daemon.pyi @@ -0,0 +1,21 @@ +"""This type stub file was generated by pyright.""" + +class CancellableStream: + """Stream wrapper for real-time events, logs, etc. from the server. + + Example: + >>> events = client.events() + >>> for event in events: + ... print(event) + >>> # and cancel from another thread + >>> events.close() + """ + + def __init__(self, stream, response) -> None: ... + def __iter__(self): ... + def __next__(self): ... + + next = ... + def close(self): # -> None: + """Closes the event streaming.""" + ... diff --git a/typings/docker/types/healthcheck.pyi b/typings/docker/types/healthcheck.pyi new file mode 100644 index 00000000..92e3198e --- /dev/null +++ b/typings/docker/types/healthcheck.pyi @@ -0,0 +1,51 @@ +"""This type stub file was generated by pyright.""" + +from .base import DictType + +class Healthcheck(DictType): + """Defines a healthcheck configuration for a container or service. + + Args: + test (:py:class:`list` or str): Test to perform to determine + container health. Possible values: + + - Empty list: Inherit healthcheck from parent image + - ``["NONE"]``: Disable healthcheck + - ``["CMD", args...]``: exec arguments directly. + - ``["CMD-SHELL", command]``: Run command in the system's + default shell. + + If a string is provided, it will be used as a ``CMD-SHELL`` + command. + interval (int): The time to wait between checks in nanoseconds. It + should be 0 or at least 1000000 (1 ms). + timeout (int): The time to wait before considering the check to + have hung. It should be 0 or at least 1000000 (1 ms). + retries (int): The number of consecutive failures needed to + consider a container as unhealthy. + start_period (int): Start period for the container to + initialize before starting health-retries countdown in + nanoseconds. It should be 0 or at least 1000000 (1 ms). + """ + + def __init__(self, **kwargs) -> None: ... + @property + def test(self): ... + @test.setter + def test(self, value): ... + @property + def interval(self): ... + @interval.setter + def interval(self, value): ... + @property + def timeout(self): ... + @timeout.setter + def timeout(self, value): ... + @property + def retries(self): ... + @retries.setter + def retries(self, value): ... + @property + def start_period(self): ... + @start_period.setter + def start_period(self, value): ... diff --git a/typings/docker/types/networks.pyi b/typings/docker/types/networks.pyi new file mode 100644 index 00000000..b143222a --- /dev/null +++ b/typings/docker/types/networks.pyi @@ -0,0 +1,66 @@ +"""This type stub file was generated by pyright.""" + +class EndpointConfig(dict): + def __init__( + self, + version, + aliases=..., + links=..., + ipv4_address=..., + ipv6_address=..., + link_local_ips=..., + driver_opt=..., + mac_address=..., + ) -> None: ... + +class NetworkingConfig(dict): + def __init__(self, endpoints_config=...) -> None: ... + +class IPAMConfig(dict): + """Create an IPAM (IP Address Management) config dictionary to be used with + :py:meth:`~docker.api.network.NetworkApiMixin.create_network`. + + Args: + driver (str): The IPAM driver to use. Defaults to ``default``. + pool_configs (:py:class:`list`): A list of pool configurations + (:py:class:`~docker.types.IPAMPool`). Defaults to empty list. + options (dict): Driver options as a key-value dictionary. + Defaults to `None`. + + Example: + >>> ipam_config = docker.types.IPAMConfig(driver='default') + >>> network = client.create_network('network1', ipam=ipam_config) + + """ + + def __init__(self, driver=..., pool_configs=..., options=...) -> None: ... + +class IPAMPool(dict): + """Create an IPAM pool config dictionary to be added to the + ``pool_configs`` parameter of + :py:class:`~docker.types.IPAMConfig`. + + Args: + subnet (str): Custom subnet for this IPAM pool using the CIDR + notation. Defaults to ``None``. + iprange (str): Custom IP range for endpoints in this IPAM pool using + the CIDR notation. Defaults to ``None``. + gateway (str): Custom IP address for the pool's gateway. + aux_addresses (dict): A dictionary of ``key -> ip_address`` + relationships specifying auxiliary addresses that need to be + allocated by the IPAM driver. + + Example: + >>> ipam_pool = docker.types.IPAMPool( + subnet='124.42.0.0/16', + iprange='124.42.0.0/24', + gateway='124.42.0.254', + aux_addresses={ + 'reserved1': '124.42.1.1' + } + ) + >>> ipam_config = docker.types.IPAMConfig( + pool_configs=[ipam_pool]) + """ + + def __init__(self, subnet=..., iprange=..., gateway=..., aux_addresses=...) -> None: ... diff --git a/typings/docker/types/services.pyi b/typings/docker/types/services.pyi new file mode 100644 index 00000000..fad6c534 --- /dev/null +++ b/typings/docker/types/services.pyi @@ -0,0 +1,427 @@ +"""This type stub file was generated by pyright.""" + +from ..utils import check_resource + +class TaskTemplate(dict): + """Describe the task specification to be used when creating or updating a + service. + + Args: + container_spec (ContainerSpec): Container settings for containers + started as part of this task. + log_driver (DriverConfig): Log configuration for containers created as + part of the service. + resources (Resources): Resource requirements which apply to each + individual container created as part of the service. + restart_policy (RestartPolicy): Specification for the restart policy + which applies to containers created as part of this service. + placement (Placement): Placement instructions for the scheduler. + If a list is passed instead, it is assumed to be a list of + constraints as part of a :py:class:`Placement` object. + networks (:py:class:`list`): List of network names or IDs or + :py:class:`NetworkAttachmentConfig` to attach the service to. + force_update (int): A counter that triggers an update even if no + relevant parameters have been changed. + """ + + def __init__( + self, + container_spec, + resources=..., + restart_policy=..., + placement=..., + log_driver=..., + networks=..., + force_update=..., + ) -> None: ... + @property + def container_spec(self): ... + @property + def resources(self): ... + @property + def restart_policy(self): ... + @property + def placement(self): ... + +class ContainerSpec(dict): + """Describes the behavior of containers that are part of a task, and is used + when declaring a :py:class:`~docker.types.TaskTemplate`. + + Args: + image (string): The image name to use for the container. + command (string or list): The command to be run in the image. + args (:py:class:`list`): Arguments to the command. + hostname (string): The hostname to set on the container. + env (dict): Environment variables. + workdir (string): The working directory for commands to run in. + user (string): The user inside the container. + labels (dict): A map of labels to associate with the service. + mounts (:py:class:`list`): A list of specifications for mounts to be + added to containers created as part of the service. See the + :py:class:`~docker.types.Mount` class for details. + stop_grace_period (int): Amount of time to wait for the container to + terminate before forcefully killing it. + secrets (:py:class:`list`): List of :py:class:`SecretReference` to be + made available inside the containers. + tty (boolean): Whether a pseudo-TTY should be allocated. + groups (:py:class:`list`): A list of additional groups that the + container process will run as. + open_stdin (boolean): Open ``stdin`` + read_only (boolean): Mount the container's root filesystem as read + only. + stop_signal (string): Set signal to stop the service's containers + healthcheck (Healthcheck): Healthcheck + configuration for this service. + hosts (:py:class:`dict`): A set of host to IP mappings to add to + the container's ``hosts`` file. + dns_config (DNSConfig): Specification for DNS + related configurations in resolver configuration file. + configs (:py:class:`list`): List of :py:class:`ConfigReference` that + will be exposed to the service. + privileges (Privileges): Security options for the service's containers. + isolation (string): Isolation technology used by the service's + containers. Only used for Windows containers. + init (boolean): Run an init inside the container that forwards signals + and reaps processes. + cap_add (:py:class:`list`): A list of kernel capabilities to add to the + default set for the container. + cap_drop (:py:class:`list`): A list of kernel capabilities to drop from + the default set for the container. + sysctls (:py:class:`dict`): A dict of sysctl values to add to + the container + """ + + def __init__( + self, + image, + command=..., + args=..., + hostname=..., + env=..., + workdir=..., + user=..., + labels=..., + mounts=..., + stop_grace_period=..., + secrets=..., + tty=..., + groups=..., + open_stdin=..., + read_only=..., + stop_signal=..., + healthcheck=..., + hosts=..., + dns_config=..., + configs=..., + privileges=..., + isolation=..., + init=..., + cap_add=..., + cap_drop=..., + sysctls=..., + ) -> None: ... + +class Mount(dict): + """Describes a mounted folder's configuration inside a container. A list of + :py:class:`Mount` would be used as part of a + :py:class:`~docker.types.ContainerSpec`. + + Args: + target (string): Container path. + source (string): Mount source (e.g. a volume name or a host path). + type (string): The mount type (``bind`` / ``volume`` / ``tmpfs`` / + ``npipe``). Default: ``volume``. + read_only (bool): Whether the mount should be read-only. + consistency (string): The consistency requirement for the mount. One of + ``default```, ``consistent``, ``cached``, ``delegated``. + propagation (string): A propagation mode with the value ``[r]private``, + ``[r]shared``, or ``[r]slave``. Only valid for the ``bind`` type. + no_copy (bool): False if the volume should be populated with the data + from the target. Default: ``False``. Only valid for the ``volume`` + type. + labels (dict): User-defined name and labels for the volume. Only valid + for the ``volume`` type. + driver_config (DriverConfig): Volume driver configuration. Only valid + for the ``volume`` type. + tmpfs_size (int or string): The size for the tmpfs mount in bytes. + tmpfs_mode (int): The permission mode for the tmpfs mount. + """ + + def __init__( + self, + target, + source, + type=..., + read_only=..., + consistency=..., + propagation=..., + no_copy=..., + labels=..., + driver_config=..., + tmpfs_size=..., + tmpfs_mode=..., + ) -> None: ... + @classmethod + def parse_mount_string(cls, string): ... + +class Resources(dict): + """Configures resource allocation for containers when made part of a + :py:class:`~docker.types.ContainerSpec`. + + Args: + cpu_limit (int): CPU limit in units of 10^9 CPU shares. + mem_limit (int): Memory limit in Bytes. + cpu_reservation (int): CPU reservation in units of 10^9 CPU shares. + mem_reservation (int): Memory reservation in Bytes. + generic_resources (dict or :py:class:`list`): Node level generic + resources, for example a GPU, using the following format: + ``{ resource_name: resource_value }``. Alternatively, a list of + of resource specifications as defined by the Engine API. + """ + + def __init__( + self, cpu_limit=..., mem_limit=..., cpu_reservation=..., mem_reservation=..., generic_resources=... + ) -> None: ... + +class UpdateConfig(dict): + """Used to specify the way container updates should be performed by a service. + + Args: + parallelism (int): Maximum number of tasks to be updated in one + iteration (0 means unlimited parallelism). Default: 0. + delay (int): Amount of time between updates, in nanoseconds. + failure_action (string): Action to take if an updated task fails to + run, or stops running during the update. Acceptable values are + ``continue``, ``pause``, as well as ``rollback`` since API v1.28. + Default: ``continue`` + monitor (int): Amount of time to monitor each updated task for + failures, in nanoseconds. + max_failure_ratio (float): The fraction of tasks that may fail during + an update before the failure action is invoked, specified as a + floating point number between 0 and 1. Default: 0 + order (string): Specifies the order of operations when rolling out an + updated task. Either ``start-first`` or ``stop-first`` are accepted. + """ + + def __init__( + self, parallelism=..., delay=..., failure_action=..., monitor=..., max_failure_ratio=..., order=... + ) -> None: ... + +class RollbackConfig(UpdateConfig): + """Used to specify the way container rollbacks should be performed by a + service. + + Args: + parallelism (int): Maximum number of tasks to be rolled back in one + iteration (0 means unlimited parallelism). Default: 0 + delay (int): Amount of time between rollbacks, in nanoseconds. + failure_action (string): Action to take if a rolled back task fails to + run, or stops running during the rollback. Acceptable values are + ``continue``, ``pause`` or ``rollback``. + Default: ``continue`` + monitor (int): Amount of time to monitor each rolled back task for + failures, in nanoseconds. + max_failure_ratio (float): The fraction of tasks that may fail during + a rollback before the failure action is invoked, specified as a + floating point number between 0 and 1. Default: 0 + order (string): Specifies the order of operations when rolling out a + rolled back task. Either ``start-first`` or ``stop-first`` are + accepted. + """ + +class RestartConditionTypesEnum: + _values = ... + +class RestartPolicy(dict): + """Used when creating a :py:class:`~docker.types.ContainerSpec`, + dictates whether a container should restart after stopping or failing. + + Args: + condition (string): Condition for restart (``none``, ``on-failure``, + or ``any``). Default: `none`. + delay (int): Delay between restart attempts. Default: 0 + max_attempts (int): Maximum attempts to restart a given container + before giving up. Default value is 0, which is ignored. + window (int): Time window used to evaluate the restart policy. Default + value is 0, which is unbounded. + """ + + condition_types = RestartConditionTypesEnum + def __init__(self, condition=..., delay=..., max_attempts=..., window=...) -> None: ... + +class DriverConfig(dict): + """Indicates which driver to use, as well as its configuration. Can be used + as ``log_driver`` in a :py:class:`~docker.types.ContainerSpec`, + for the `driver_config` in a volume :py:class:`~docker.types.Mount`, or + as the driver object in + :py:meth:`create_secret`. + + Args: + name (string): Name of the driver to use. + options (dict): Driver-specific options. Default: ``None``. + """ + + def __init__(self, name, options=...) -> None: ... + +class EndpointSpec(dict): + """Describes properties to access and load-balance a service. + + Args: + mode (string): The mode of resolution to use for internal load + balancing between tasks (``'vip'`` or ``'dnsrr'``). Defaults to + ``'vip'`` if not provided. + ports (dict): Exposed ports that this service is accessible on from the + outside, in the form of ``{ published_port: target_port }`` or + ``{ published_port: }``. Port config tuple format + is ``(target_port [, protocol [, publish_mode]])``. + Ports can only be provided if the ``vip`` resolution mode is used. + """ + + def __init__(self, mode=..., ports=...) -> None: ... + +def convert_service_ports(ports): ... + +class ServiceMode(dict): + """Indicate whether a service or a job should be deployed as a replicated + or global service, and associated parameters. + + Args: + mode (string): Can be either ``replicated``, ``global``, + ``replicated-job`` or ``global-job`` + replicas (int): Number of replicas. For replicated services only. + concurrency (int): Number of concurrent jobs. For replicated job + services only. + """ + + def __init__(self, mode, replicas=..., concurrency=...) -> None: ... + @property + def replicas(self): ... + +class SecretReference(dict): + """Secret reference to be used as part of a :py:class:`ContainerSpec`. + Describes how a secret is made accessible inside the service's + containers. + + Args: + secret_id (string): Secret's ID + secret_name (string): Secret's name as defined at its creation. + filename (string): Name of the file containing the secret. Defaults + to the secret's name if not specified. + uid (string): UID of the secret file's owner. Default: 0 + gid (string): GID of the secret file's group. Default: 0 + mode (int): File access mode inside the container. Default: 0o444 + """ + + @check_resource("secret_id") + def __init__(self, secret_id, secret_name, filename=..., uid=..., gid=..., mode=...) -> None: ... + +class ConfigReference(dict): + """Config reference to be used as part of a :py:class:`ContainerSpec`. + Describes how a config is made accessible inside the service's + containers. + + Args: + config_id (string): Config's ID + config_name (string): Config's name as defined at its creation. + filename (string): Name of the file containing the config. Defaults + to the config's name if not specified. + uid (string): UID of the config file's owner. Default: 0 + gid (string): GID of the config file's group. Default: 0 + mode (int): File access mode inside the container. Default: 0o444 + """ + + @check_resource("config_id") + def __init__(self, config_id, config_name, filename=..., uid=..., gid=..., mode=...) -> None: ... + +class Placement(dict): + """Placement constraints to be used as part of a :py:class:`TaskTemplate`. + + Args: + constraints (:py:class:`list` of str): A list of constraints + preferences (:py:class:`list` of tuple): Preferences provide a way + to make the scheduler aware of factors such as topology. They + are provided in order from highest to lowest precedence and + are expressed as ``(strategy, descriptor)`` tuples. See + :py:class:`PlacementPreference` for details. + maxreplicas (int): Maximum number of replicas per node + platforms (:py:class:`list` of tuple): A list of platforms + expressed as ``(arch, os)`` tuples + """ + + def __init__(self, constraints=..., preferences=..., platforms=..., maxreplicas=...) -> None: ... + +class PlacementPreference(dict): + """Placement preference to be used as an element in the list of + preferences for :py:class:`Placement` objects. + + Args: + strategy (string): The placement strategy to implement. Currently, + the only supported strategy is ``spread``. + descriptor (string): A label descriptor. For the spread strategy, + the scheduler will try to spread tasks evenly over groups of + nodes identified by this label. + """ + + def __init__(self, strategy, descriptor) -> None: ... + +class DNSConfig(dict): + """Specification for DNS related configurations in resolver configuration + file (``resolv.conf``). Part of a :py:class:`ContainerSpec` definition. + + Args: + nameservers (:py:class:`list`): The IP addresses of the name + servers. + search (:py:class:`list`): A search list for host-name lookup. + options (:py:class:`list`): A list of internal resolver variables + to be modified (e.g., ``debug``, ``ndots:3``, etc.). + """ + + def __init__(self, nameservers=..., search=..., options=...) -> None: ... + +class Privileges(dict): + r"""Security options for a service's containers. + Part of a :py:class:`ContainerSpec` definition. + + Args: + credentialspec_file (str): Load credential spec from this file. + The file is read by the daemon, and must be present in the + CredentialSpecs subdirectory in the docker data directory, + which defaults to ``C:\ProgramData\Docker\`` on Windows. + Can not be combined with credentialspec_registry. + + credentialspec_registry (str): Load credential spec from this value + in the Windows registry. The specified registry value must be + located in: ``HKLM\SOFTWARE\Microsoft\Windows NT\CurrentVersion + \Virtualization\Containers\CredentialSpecs``. + Can not be combined with credentialspec_file. + + selinux_disable (boolean): Disable SELinux + selinux_user (string): SELinux user label + selinux_role (string): SELinux role label + selinux_type (string): SELinux type label + selinux_level (string): SELinux level label + """ + def __init__( + self, + credentialspec_file=..., + credentialspec_registry=..., + selinux_disable=..., + selinux_user=..., + selinux_role=..., + selinux_type=..., + selinux_level=..., + ) -> None: ... + +class NetworkAttachmentConfig(dict): + """Network attachment options for a service. + + Args: + target (str): The target network for attachment. + Can be a network name or ID. + aliases (:py:class:`list`): A list of discoverable alternate names + for the service. + options (:py:class:`dict`): Driver attachment options for the + network target. + """ + + def __init__(self, target, aliases=..., options=...) -> None: ... diff --git a/typings/docker/types/swarm.pyi b/typings/docker/types/swarm.pyi new file mode 100644 index 00000000..d186c13e --- /dev/null +++ b/typings/docker/types/swarm.pyi @@ -0,0 +1,48 @@ +"""This type stub file was generated by pyright.""" + +class SwarmSpec(dict): + """Describe a Swarm's configuration and options. Use + :py:meth:`~docker.api.swarm.SwarmApiMixin.create_swarm_spec` + to instantiate. + """ + + def __init__( + self, + version, + task_history_retention_limit=..., + snapshot_interval=..., + keep_old_snapshots=..., + log_entries_for_slow_followers=..., + heartbeat_tick=..., + election_tick=..., + dispatcher_heartbeat_period=..., + node_cert_expiry=..., + external_cas=..., + name=..., + labels=..., + signing_ca_cert=..., + signing_ca_key=..., + ca_force_rotate=..., + autolock_managers=..., + log_driver=..., + ) -> None: ... + +class SwarmExternalCA(dict): + """Configuration for forwarding signing requests to an external + certificate authority. + + Args: + url (string): URL where certificate signing requests should be + sent. + protocol (string): Protocol for communication with the external CA. + options (dict): An object with key/value pairs that are interpreted + as protocol-specific options for the external CA driver. + ca_cert (string): The root CA certificate (in PEM format) this + external CA uses to issue TLS certificates (assumed to be to + the current swarm root CA certificate if not provided). + + + + """ + + def __init__(self, url, protocol=..., options=..., ca_cert=...) -> None: ... diff --git a/typings/docker/utils/__init__.pyi b/typings/docker/utils/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/typings/docker/utils/build.pyi b/typings/docker/utils/build.pyi new file mode 100644 index 00000000..2edae7be --- /dev/null +++ b/typings/docker/utils/build.pyi @@ -0,0 +1,31 @@ +"""This type stub file was generated by pyright.""" + +_SEP = ... + +def tar(path, exclude=..., dockerfile=..., fileobj=..., gzip=...): ... +def exclude_paths(root, patterns, dockerfile=...): # -> set[Unknown]: + """Given a root directory path and a list of .dockerignore patterns, return + an iterator of all paths (both regular files and directories) in the root + directory that do *not* match any of the patterns. + + All paths returned are relative to the root. + """ + ... + +def build_file_list(root): ... +def create_archive(root, files=..., fileobj=..., gzip=..., extra_files=...): ... +def mkbuildcontext(dockerfile): ... +def split_path(p): ... +def normalize_slashes(p): ... +def walk(root, patterns, default=...): ... + +class PatternMatcher: + def __init__(self, patterns) -> None: ... + def matches(self, filepath): ... + def walk(self, root): ... + +class Pattern: + def __init__(self, pattern_str) -> None: ... + @classmethod + def normalize(cls, p): ... + def match(self, filepath): ... diff --git a/typings/docker/utils/config.pyi b/typings/docker/utils/config.pyi new file mode 100644 index 00000000..682b7909 --- /dev/null +++ b/typings/docker/utils/config.pyi @@ -0,0 +1,15 @@ +"""This type stub file was generated by pyright.""" + +DOCKER_CONFIG_FILENAME = ... +LEGACY_DOCKER_CONFIG_FILENAME = ... +log = ... + +def find_config_file(config_path=...): ... +def config_path_from_environment(): ... +def home_dir(): # -> str: + """Get the user's home directory, using the same logic as the Docker Engine + client - use %USERPROFILE% on Windows, $HOME/getuid on POSIX. + """ + ... + +def load_general_config(config_path=...): ... diff --git a/typings/docker/utils/decorators.pyi b/typings/docker/utils/decorators.pyi new file mode 100644 index 00000000..6646f254 --- /dev/null +++ b/typings/docker/utils/decorators.pyi @@ -0,0 +1,5 @@ +"""This type stub file was generated by pyright.""" + +def check_resource(resource_name): ... +def minimum_version(version): ... +def update_headers(f): ... diff --git a/typings/docker/utils/fnmatch.pyi b/typings/docker/utils/fnmatch.pyi new file mode 100644 index 00000000..b1f9391c --- /dev/null +++ b/typings/docker/utils/fnmatch.pyi @@ -0,0 +1,47 @@ +"""This type stub file was generated by pyright.""" + +"""Filename matching with shell patterns. + +fnmatch(FILENAME, PATTERN) matches according to the local convention. +fnmatchcase(FILENAME, PATTERN) always takes case in account. + +The functions operate by translating the pattern into a regular +expression. They cache the compiled regular expressions for speed. + +The function translate(PATTERN) returns a regular expression +corresponding to PATTERN. (It does not compile it.) +""" +__all__ = ["fnmatch", "fnmatchcase", "translate"] +_cache = ... +_MAXCACHE = ... + +def fnmatch(name, pat): # -> bool: + """Test whether FILENAME matches PATTERN. + + Patterns are Unix shell style: + + * matches everything + ? matches any single character + [seq] matches any character in seq + [!seq] matches any char not in seq + + An initial period in FILENAME is not special. + Both FILENAME and PATTERN are first case-normalized + if the operating system requires it. + If you don't want this, use fnmatchcase(FILENAME, PATTERN). + """ + ... + +def fnmatchcase(name, pat): # -> bool: + """Test whether FILENAME matches PATTERN, including case. + This is a version of fnmatch() which doesn't case-normalize + its arguments. + """ + ... + +def translate(pat): # -> LiteralString | str: + """Translate a shell PATTERN to a regular expression. + + There is no way to quote meta-characters. + """ + ... diff --git a/typings/docker/utils/json_stream.pyi b/typings/docker/utils/json_stream.pyi new file mode 100644 index 00000000..e13b8ce8 --- /dev/null +++ b/typings/docker/utils/json_stream.pyi @@ -0,0 +1,34 @@ +"""This type stub file was generated by pyright.""" + +json_decoder = ... + +def stream_as_text(stream): # -> Generator[Unknown | str, Any, None]: + """Given a stream of bytes or text, if any of the items in the stream + are bytes convert them to text. + This function can be removed once we return text streams + instead of byte streams. + """ + ... + +def json_splitter(buffer): # -> tuple[Any, Unknown] | None: + """Attempt to parse a json object from a buffer. If there is at least one + object, return it and the rest of the buffer, otherwise return None. + """ + ... + +def json_stream(stream): # -> Generator[Any, Any, None]: + """Given a stream of text, return a stream of json objects. + This handles streams which are inconsistently buffered (some entries may + be newline delimited, and others are not). + """ + ... + +def line_splitter(buffer, separator=...): ... +def split_buffer(stream, splitter=..., decoder=...): # -> Generator[Unknown | str, Any, None]: + """Given a generator which yields strings and a splitter function, + joins all input, splits on the separator and yields each chunk. + Unlike string.split(), each chunk includes the trailing + separator, except for the last one if none was found on the end + of the input. + """ + ... diff --git a/typings/docker/utils/proxy.pyi b/typings/docker/utils/proxy.pyi new file mode 100644 index 00000000..0fddb181 --- /dev/null +++ b/typings/docker/utils/proxy.pyi @@ -0,0 +1,32 @@ +"""This type stub file was generated by pyright.""" + +class ProxyConfig(dict): + """Hold the client's proxy configuration.""" + + @property + def http(self): ... + @property + def https(self): ... + @property + def ftp(self): ... + @property + def no_proxy(self): ... + @staticmethod + def from_dict(config): # -> ProxyConfig: + """Instantiate a new ProxyConfig from a dictionary that represents a + client configuration, as described in `the documentation`_. + + .. _the documentation: + https://docs.docker.com/network/proxy/#configure-the-docker-client + """ + ... + def get_environment(self): # -> dict[Unknown, Unknown]: + """Return a dictionary representing the environment variables used to + set the proxy settings. + """ + ... + def inject_proxy_environment(self, environment): # -> list[Unknown | str]: + """Given a list of strings representing environment variables, prepend the + environment variables corresponding to the proxy settings. + """ + ... diff --git a/typings/docker/utils/socket.pyi b/typings/docker/utils/socket.pyi new file mode 100644 index 00000000..1acad8b2 --- /dev/null +++ b/typings/docker/utils/socket.pyi @@ -0,0 +1,67 @@ +"""This type stub file was generated by pyright.""" + +STDOUT = ... +STDERR = ... + +class SocketError(Exception): ... + +NPIPE_ENDED = ... + +def read(socket, n=...): + """Reads at most n bytes from socket.""" + ... + +def read_exactly(socket, n): # -> bytes: + """Reads exactly n bytes from socket + Raises SocketError if there isn't enough data. + """ + ... + +def next_frame_header(socket): # -> tuple[Literal[-1], Literal[-1]] | tuple[Any, Any]: + """Returns the stream and size of the next frame of data waiting to be read + from socket, according to the protocol defined here: + + https://docs.docker.com/engine/api/v1.24/#attach-to-a-container + """ + ... + +def frames_iter( + socket, tty +): # -> Generator[tuple[Literal[1], Unknown], None, None] | Generator[tuple[Any | Literal[-1], Unknown], Any, None]: + """Return a generator of frames read from socket. A frame is a tuple where + the first item is the stream number and the second item is a chunk of data. + + If the tty setting is enabled, the streams are multiplexed into the stdout + stream. + """ + ... + +def frames_iter_no_tty(socket): # -> Generator[tuple[Any | Literal[-1], Unknown], Any, None]: + """Returns a generator of data read from the socket when the tty setting is + not enabled. + """ + ... + +def frames_iter_tty(socket): # -> Generator[Unknown, Any, None]: + """Return a generator of data read from the socket when the tty setting is + enabled. + """ + ... + +def consume_socket_output(frames, demux=...): # -> bytes | tuple[None, ...]: + """Iterate through frames read from the socket and return the result. + + Args: + demux (bool): + If False, stdout and stderr are multiplexed, and the result is the + concatenation of all the frames. If True, the streams are + demultiplexed, and the result is a 2-tuple where each item is the + concatenation of frames belonging to the same stream. + """ + ... + +def demux_adaptor(stream_id, data): # -> tuple[Unknown, None] | tuple[None, Unknown]: + """Utility to demultiplex stdout and stderr when reading frames from the + socket. + """ + ... diff --git a/typings/docker/utils/utils.pyi b/typings/docker/utils/utils.pyi new file mode 100644 index 00000000..95e594ae --- /dev/null +++ b/typings/docker/utils/utils.pyi @@ -0,0 +1,48 @@ +"""This type stub file was generated by pyright.""" + +URLComponents = ... + +def create_ipam_pool(*args, **kwargs): ... +def create_ipam_config(*args, **kwargs): ... +def decode_json_header(header): ... +def compare_version(v1, v2): # -> Literal[0, -1, 1]: + """Compare docker versions. + + >>> v1 = '1.9' + >>> v2 = '1.10' + >>> compare_version(v1, v2) + 1 + >>> compare_version(v2, v1) + -1 + >>> compare_version(v2, v2) + 0 + """ + ... + +def version_lt(v1, v2): ... +def version_gte(v1, v2): ... +def convert_port_bindings(port_bindings): ... +def convert_volume_binds(binds): ... +def convert_tmpfs_mounts(tmpfs): ... +def convert_service_networks(networks): ... +def parse_repository_tag(repo_name): ... +def parse_host(addr, is_win32=..., tls=...): ... +def parse_devices(devices): ... +def kwargs_from_env(ssl_version=..., assert_hostname=..., environment=...): ... +def convert_filters(filters): ... +def datetime_to_timestamp(dt): + """Convert a UTC datetime to a Unix timestamp.""" + ... + +def parse_bytes(s): ... +def normalize_links(links): ... +def parse_env_file(env_file): # -> dict[Unknown, Unknown]: + """Reads a line-separated environment file. + The format of each line should be "key=value". + """ + ... + +def split_command(command): ... +def format_environment(environment): ... +def format_extra_hosts(extra_hosts, task=...): ... +def create_host_config(self, *args, **kwargs): ... diff --git a/typings/nbformat/__init__.pyi b/typings/nbformat/__init__.pyi new file mode 100644 index 00000000..4168e7f2 --- /dev/null +++ b/typings/nbformat/__init__.pyi @@ -0,0 +1 @@ +from . import v4 as v4 diff --git a/typings/nbformat/v4/__init__.pyi b/typings/nbformat/v4/__init__.pyi new file mode 100644 index 00000000..7d81cb7f --- /dev/null +++ b/typings/nbformat/v4/__init__.pyi @@ -0,0 +1,35 @@ +"""The main API for the v4 notebook format.""" +from .convert import downgrade +from .convert import upgrade +from .nbbase import nbformat +from .nbbase import nbformat_minor +from .nbbase import nbformat_schema +from .nbbase import new_code_cell +from .nbbase import new_markdown_cell +from .nbbase import new_notebook +from .nbbase import new_output +from .nbbase import new_raw_cell +from .nbbase import output_from_msg +from .nbjson import reads +from .nbjson import to_notebook +from .nbjson import writes + +__all__ = [ + "nbformat", + "nbformat_minor", + "nbformat_schema", + "new_code_cell", + "new_markdown_cell", + "new_raw_cell", + "new_notebook", + "new_output", + "output_from_msg", + "reads", + "writes", + "to_notebook", + "downgrade", + "upgrade", +] +reads_json = ... +writes_json = ... +to_notebook_json = ... diff --git a/typings/nbformat/v4/convert.pyi b/typings/nbformat/v4/convert.pyi new file mode 100644 index 00000000..99f4467c --- /dev/null +++ b/typings/nbformat/v4/convert.pyi @@ -0,0 +1,95 @@ +""" +This type stub file was generated by pyright. +""" + +"""Code for converting notebooks to and from v3.""" + +def upgrade(nb, from_version=..., from_minor=...): + """Convert a notebook to latest v4. + + Parameters + ---------- + nb : NotebookNode + The Python representation of the notebook to convert. + from_version : int + The original version of the notebook to convert. + from_minor : int + The original minor version of the notebook to convert (only relevant for v >= 3). + """ + ... + +def upgrade_cell(cell): + """upgrade a cell from v3 to v4 + + heading cell: + - -> markdown heading + code cell: + - remove language metadata + - cell.input -> cell.source + - cell.prompt_number -> cell.execution_count + - update outputs + """ + ... + +def downgrade_cell(cell): + """downgrade a cell from v4 to v3 + + code cell: + - set cell.language + - cell.input <- cell.source + - cell.prompt_number <- cell.execution_count + - update outputs + markdown cell: + - single-line heading -> heading cell + """ + ... + +_mime_map = ... + +def to_mime_key(d): + """convert dict with v3 aliases to plain mime-type keys""" + ... + +def from_mime_key(d): # -> dict[Unknown, Unknown]: + """convert dict with mime-type keys to v3 aliases""" + ... + +def upgrade_output(output): + """upgrade a single code cell output from v3 to v4 + + - pyout -> execute_result + - pyerr -> error + - output.type -> output.data.mime/type + - mime-type keys + - stream.stream -> stream.name + """ + ... + +def downgrade_output(output): + """downgrade a single code cell output to v3 from v4 + + - pyout <- execute_result + - pyerr <- error + - output.data.mime/type -> output.type + - un-mime-type keys + - stream.stream <- stream.name + """ + ... + +def upgrade_outputs(outputs): # -> list[Unknown]: + """upgrade outputs of a code cell from v3 to v4""" + ... + +def downgrade_outputs(outputs): # -> list[Unknown]: + """downgrade outputs of a code cell to v3 from v4""" + ... + +def downgrade(nb): + """Convert a v4 notebook to v3. + + Parameters + ---------- + nb : NotebookNode + The Python representation of the notebook to convert. + """ + ... diff --git a/typings/nbformat/v4/nbbase.pyi b/typings/nbformat/v4/nbbase.pyi new file mode 100644 index 00000000..4af4412d --- /dev/null +++ b/typings/nbformat/v4/nbbase.pyi @@ -0,0 +1,52 @@ +""" +This type stub file was generated by pyright. +""" + +"""Python API for composing notebook elements + +The Python representation of a notebook is a nested structure of +dictionary subclasses that support attribute access. +The functions in this module are merely helpers to build the structs +in the right form. +""" +from typing import Any + +import nbformat + +def validate(node, ref=...): # -> None: + """validate a v4 node""" + ... + +def new_output(output_type, data=..., **kwargs): # -> NotebookNode: + """Create a new output, to go in the ``cell.outputs`` list of a code cell.""" + ... + +def output_from_msg(msg): # -> NotebookNode: + """Create a NotebookNode for an output from a kernel's IOPub message. + + Returns + ------- + NotebookNode: the output as a notebook node. + + Raises + ------ + ValueError: if the message is not an output message. + + """ + ... + +def new_code_cell(source=..., **kwargs): # -> NotebookNode: + """Create a new code cell""" + ... + +def new_markdown_cell(source: str = ..., **kwargs: Any) -> nbformat.NotebookNode: + """Create a new markdown cell""" + ... + +def new_raw_cell(source=..., **kwargs): # -> NotebookNode: + """Create a new raw cell""" + ... + +def new_notebook(**kwargs): # -> NotebookNode: + """Create a new notebook""" + ... diff --git a/typings/nbformat/v4/nbjson.pyi b/typings/nbformat/v4/nbjson.pyi new file mode 100644 index 00000000..b2d49b37 --- /dev/null +++ b/typings/nbformat/v4/nbjson.pyi @@ -0,0 +1,45 @@ +""" +This type stub file was generated by pyright. +""" + +import json + +from .rwbase import NotebookReader +from .rwbase import NotebookWriter + +"""Read and write notebooks in JSON format.""" + +class BytesEncoder(json.JSONEncoder): + """A JSON encoder that accepts b64 (and other *ascii*) bytestrings.""" + + def default(self, obj): # -> str | Any: + """Get the default value of an object.""" + ... + +class JSONReader(NotebookReader): + """A JSON notebook reader.""" + + def reads(self, s, **kwargs): + """Read a JSON string into a Notebook object""" + ... + def to_notebook(self, d, **kwargs): + """Convert a disk-format notebook dict to in-memory NotebookNode + + handles multi-line values as strings, scrubbing of transient values, etc. + """ + ... + +class JSONWriter(NotebookWriter): + """A JSON notebook writer.""" + + def writes(self, nb, **kwargs): # -> str: + """Serialize a NotebookNode object as a JSON string""" + ... + +_reader = ... +_writer = ... +reads = ... +read = ... +to_notebook = ... +write = ... +writes = ... diff --git a/typings/nbformat/v4/rwbase.pyi b/typings/nbformat/v4/rwbase.pyi new file mode 100644 index 00000000..179a59ca --- /dev/null +++ b/typings/nbformat/v4/rwbase.pyi @@ -0,0 +1,56 @@ +""" +This type stub file was generated by pyright. +""" + +"""Base classes and utilities for readers and writers.""" + +def rejoin_lines(nb): + """rejoin multiline text into strings + + For reversing effects of ``split_lines(nb)``. + + This only rejoins lines that have been split, so if text objects were not split + they will pass through unchanged. + + Used when reading JSON files that may have been passed through split_lines. + """ + ... + +_non_text_split_mimes = ... + +def split_lines(nb): + """split likely multiline text into lists of strings + + For file output more friendly to line-based VCS. ``rejoin_lines(nb)`` will + reverse the effects of ``split_lines(nb)``. + + Used when writing JSON files. + """ + ... + +def strip_transient(nb): + """Strip transient values that shouldn't be stored in files. + + This should be called in *both* read and write. + """ + ... + +class NotebookReader: + """A class for reading notebooks.""" + + def reads(self, s, **kwargs): + """Read a notebook from a string.""" + ... + def read(self, fp, **kwargs): + """Read a notebook from a file like object""" + ... + +class NotebookWriter: + """A class for writing notebooks.""" + + def writes(self, nb, **kwargs): + """Write a notebook to a string.""" + ... + def write(self, nb, fp, **kwargs): + """Write a notebook to a file like object""" + ... diff --git a/typings/rsmiBindings.pyi b/typings/rsmiBindings.pyi index 3e3bf72b..e0a23119 100644 --- a/typings/rsmiBindings.pyi +++ b/typings/rsmiBindings.pyi @@ -1,10 +1,12 @@ # See https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py import ctypes +from typing import Any +from typing import Literal from typing import LiteralString class rocmsmi(ctypes.CDLL): @staticmethod - def rsmi_num_monitor_devices(num_devices: ctypes._CArgObject) -> LiteralString: ... + def rsmi_num_monitor_devices(num_devices: ctypes._CArgObject) -> Any: ... # Device ID dv_id: ctypes.c_uint64 = ... @@ -12,45 +14,28 @@ dv_id: ctypes.c_uint64 = ... gpu_id: ctypes.c_uint32 = ... # Policy enums -RSMI_MAX_NUM_FREQUENCIES = 32 -RSMI_MAX_FAN_SPEED = 255 -RSMI_NUM_VOLTAGE_CURVE_POINTS = 3 +RSMI_MAX_NUM_FREQUENCIES: Literal[32] = ... +RSMI_MAX_FAN_SPEED: Literal[255] = ... +RSMI_NUM_VOLTAGE_CURVE_POINTS: Literal[3] = ... class rsmi_status_t(ctypes.c_int): - RSMI_STATUS_SUCCESS = 0x0 - RSMI_STATUS_INVALID_ARGS = 0x1 - RSMI_STATUS_NOT_SUPPORTED = 0x2 - RSMI_STATUS_FILE_ERROR = 0x3 - RSMI_STATUS_PERMISSION = 0x4 - RSMI_STATUS_OUT_OF_RESOURCES = 0x5 - RSMI_STATUS_INTERNAL_EXCEPTION = 0x6 - RSMI_STATUS_INPUT_OUT_OF_BOUNDS = 0x7 - RSMI_STATUS_INIT_ERROR = 0x8 + RSMI_STATUS_SUCCESS: Literal[0x0] = ... + RSMI_STATUS_INVALID_ARGS: Literal[0x1] = ... + RSMI_STATUS_NOT_SUPPORTED: Literal[0x2] = ... + RSMI_STATUS_FILE_ERROR: Literal[0x3] = ... + RSMI_STATUS_PERMISSION: Literal[0x4] = ... + RSMI_STATUS_OUT_OF_RESOURCES: Literal[0x5] = ... + RSMI_STATUS_INTERNAL_EXCEPTION: Literal[0x6] = ... + RSMI_STATUS_INPUT_OUT_OF_BOUNDS: Literal[0x7] = ... + RSMI_STATUS_INIT_ERROR: Literal[0x8] = ... RSMI_INITIALIZATION_ERROR = RSMI_STATUS_INIT_ERROR - RSMI_STATUS_NOT_YET_IMPLEMENTED = 0x9 - RSMI_STATUS_NOT_FOUND = 0xA - RSMI_STATUS_INSUFFICIENT_SIZE = 0xB - RSMI_STATUS_INTERRUPT = 0xC - RSMI_STATUS_UNEXPECTED_SIZE = 0xD - RSMI_STATUS_NO_DATA = 0xE - RSMI_STATUS_UNKNOWN_ERROR = 0xFFFFFFFF + RSMI_STATUS_NOT_YET_IMPLEMENTED: Literal[0x9] = ... + RSMI_STATUS_NOT_FOUND: Literal[0xA] = ... + RSMI_STATUS_INSUFFICIENT_SIZE: Literal[0xB] = ... + RSMI_STATUS_INTERRUPT: Literal[0xC] = ... + RSMI_STATUS_UNEXPECTED_SIZE: Literal[0xD] = ... + RSMI_STATUS_NO_DATA: Literal[0xE] = ... + RSMI_STATUS_UNKNOWN_ERROR: Literal[0xFFFFFFFF] = ... # Dictionary of rsmi ret codes and it's verbose output -rsmi_status_verbose_err_out = { - rsmi_status_t.RSMI_STATUS_SUCCESS: "Operation was successful", - rsmi_status_t.RSMI_STATUS_INVALID_ARGS: "Invalid arguments provided", - rsmi_status_t.RSMI_STATUS_NOT_SUPPORTED: "Not supported on the given system", - rsmi_status_t.RSMI_STATUS_FILE_ERROR: "Problem accessing a file", - rsmi_status_t.RSMI_STATUS_PERMISSION: "Permission denied", - rsmi_status_t.RSMI_STATUS_OUT_OF_RESOURCES: "Unable to acquire memory or other resource", - rsmi_status_t.RSMI_STATUS_INTERNAL_EXCEPTION: "An internal exception was caught", - rsmi_status_t.RSMI_STATUS_INPUT_OUT_OF_BOUNDS: "Provided input is out of allowable or safe range", - rsmi_status_t.RSMI_INITIALIZATION_ERROR: "Error occured during rsmi initialization", - rsmi_status_t.RSMI_STATUS_NOT_YET_IMPLEMENTED: "Requested function is not implemented on this setup", - rsmi_status_t.RSMI_STATUS_NOT_FOUND: "Item searched for but not found", - rsmi_status_t.RSMI_STATUS_INSUFFICIENT_SIZE: "Insufficient resources available", - rsmi_status_t.RSMI_STATUS_INTERRUPT: "Interrupt occured during execution", - rsmi_status_t.RSMI_STATUS_UNEXPECTED_SIZE: "Unexpected amount of data read", - rsmi_status_t.RSMI_STATUS_NO_DATA: "No data found for the given input", - rsmi_status_t.RSMI_STATUS_UNKNOWN_ERROR: "Unknown error occured", -} +rsmi_status_verbose_err_out: dict[LiteralString, LiteralString] = ...