From 97d76eec85d6982fb32c785c48b951be322379cc Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 23 May 2024 10:02:23 -0400 Subject: [PATCH] tests: add additional basic testing (#982) * chore: update rebase tests Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update partial clients before removing Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * fix: update clients parsing logics to work with 0.5 Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: ignore ci runs as to run locally Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update async client tests Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> * chore: update pre-commit Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com> --- .github/workflows/ci.yml | 290 +++++++++--------- .pre-commit-config.yaml | 2 +- .python-version-default | 2 +- examples/api_server.py | 4 +- examples/async_client.py | 4 + examples/openai_chat_completion_client.py | 2 +- hatch.toml | 13 +- openllm-client/pyproject.toml | 2 +- openllm-client/src/openllm_client/_http.py | 88 +++--- openllm-client/src/openllm_client/_schemas.py | 63 +--- openllm-client/src/openllm_client/_shim.py | 184 ++++++----- openllm-client/src/openllm_client/_stream.py | 95 +++--- .../src/openllm_client/_typing_compat.py | 1 + openllm-python/README.md | 2 +- openllm-python/pyproject.toml | 39 --- openllm-python/src/openllm/__main__.py | 2 +- .../src/openllm/entrypoints/openai.py | 2 +- openllm-python/tests/__init__.py | 9 - openllm-python/tests/_data.py | 78 +++++ openllm-python/tests/_strategies/__init__.py | 0 .../tests/_strategies/_configuration.py | 60 ---- openllm-python/tests/configuration_test.py | 152 --------- openllm-python/tests/conftest.py | 50 +-- openllm-python/tests/models_test.py | 29 -- openllm-python/tests/package_test.py | 60 ---- openllm-python/tests/regression_test.py | 56 ++++ openllm-python/tests/strategies_test.py | 185 ----------- pyproject.toml | 2 +- 28 files changed, 498 insertions(+), 978 deletions(-) create mode 100644 examples/async_client.py create mode 100644 openllm-python/tests/_data.py delete mode 100644 openllm-python/tests/_strategies/__init__.py delete mode 100644 openllm-python/tests/_strategies/_configuration.py delete mode 100644 openllm-python/tests/configuration_test.py delete mode 100644 openllm-python/tests/models_test.py delete mode 100644 openllm-python/tests/package_test.py create mode 100644 openllm-python/tests/regression_test.py delete mode 100644 openllm-python/tests/strategies_test.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 82e96b59..df3259eb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,146 +1,144 @@ -name: Continuous Integration -on: - workflow_call: - push: - branches: [main] - paths-ignore: - - 'docs/**' - - 'bazel/**' - - 'typings/**' - - '*.md' - - 'changelog.d/**' - - 'assets/**' - pull_request: - branches: [main] - paths-ignore: - - 'docs/**' - - 'bazel/**' - - 'typings/**' - - '*.md' - - 'changelog.d/**' - - 'assets/**' -env: - LINES: 120 - COLUMNS: 120 - OPENLLM_DO_NOT_TRACK: True - PYTHONUNBUFFERED: '1' - HATCH_VERBOSE: 2 -# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun -defaults: - run: - shell: bash --noprofile --norc -exo pipefail {0} -jobs: 
- tests: - runs-on: ${{ matrix.os }} - if: ${{ github.event_name == 'pull_request' || github.event_name == 'push'|| github.event_name == 'workflow_call' }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.8', '3.11'] - exclude: - - os: 'windows-latest' - name: tests (${{ matrix.python-version }}.${{ matrix.os }}) - steps: - - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6 - with: - fetch-depth: 0 - ref: ${{ github.event.pull_request.head.sha }} - - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1 - with: - bentoml-version: 'main' - python-version: ${{ matrix.python-version }} - - name: Run tests - run: hatch run tests:python - - name: Disambiguate coverage filename - run: mv .coverage ".coverage.${{ matrix.os }}.${{ matrix.python-version }}" - - name: Upload coverage data - uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3 - with: - name: coverage-data - path: .coverage.* - coverage: - name: report-coverage - runs-on: ubuntu-latest - if: false - needs: tests - steps: - - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6 - with: - fetch-depth: 0 - ref: ${{ github.event.pull_request.head.sha }} - - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1 - with: - bentoml-version: 'main' - python-version-file: .python-version-default - - name: Download coverage data - uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a # ratchet:actions/download-artifact@v3 - with: - name: coverage-data - - name: Combine coverage data - run: hatch run coverage:combine - - name: Export coverage reports - run: | - hatch run coverage:report-xml openllm-python - hatch run coverage:report-uncovered-html openllm-python - - name: Upload uncovered HTML report - uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3 - with: - name: uncovered-html-report - path: htmlcov - - name: Generate coverage summary - run: hatch run coverage:generate-summary - - name: Write coverage summary report - if: github.event_name == 'pull_request' - run: hatch run coverage:write-summary-report - - name: Update coverage pull request comment - if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork - uses: marocchino/sticky-pull-request-comment@331f8f5b4215f0445d3c07b4967662a32a2d3e31 # ratchet:marocchino/sticky-pull-request-comment@v2 - with: - path: coverage-report.md - cli-benchmark: - name: Check for CLI responsiveness - runs-on: ubuntu-latest - env: - HYPERFINE_VERSION: '1.12.0' - steps: - - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6 - with: - fetch-depth: 0 - - name: Install hyperfine - run: | - wget https://github.com/sharkdp/hyperfine/releases/download/v${HYPERFINE_VERSION}/hyperfine_${HYPERFINE_VERSION}_amd64.deb - sudo dpkg -i hyperfine_${HYPERFINE_VERSION}_amd64.deb - - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1 - with: - bentoml-version: 'main' - python-version-file: .python-version-default - - name: Install self - run: bash local.sh - - name: Speed - run: hyperfine -m 100 --warmup 10 openllm - brew-dry-run: - name: Running dry-run tests for brew - 
runs-on: macos-latest - steps: - - name: Install tap and dry-run - run: | - brew tap bentoml/openllm https://github.com/bentoml/openllm - brew install openllm - openllm --help - openllm models --show-available - evergreen: # https://github.com/marketplace/actions/alls-green#why - if: always() - needs: - - tests - - cli-benchmark - - brew-dry-run - runs-on: ubuntu-latest - steps: - - name: Decide whether the needed jobs succeeded or failed - uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # ratchet:re-actors/alls-green@release/v1 - with: - jobs: ${{ toJSON(needs) }} -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} - cancel-in-progress: true +# name: Continuous Integration +# on: +# workflow_call: +# push: +# branches: [main] +# paths-ignore: +# - 'docs/**' +# - 'bazel/**' +# - 'typings/**' +# - '*.md' +# - 'changelog.d/**' +# - 'assets/**' +# pull_request: +# branches: [main] +# paths-ignore: +# - 'docs/**' +# - 'bazel/**' +# - 'typings/**' +# - '*.md' +# - 'changelog.d/**' +# - 'assets/**' +# env: +# LINES: 120 +# COLUMNS: 120 +# OPENLLM_DO_NOT_TRACK: True +# PYTHONUNBUFFERED: '1' +# HATCH_VERBOSE: 2 +# # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun +# defaults: +# run: +# shell: bash --noprofile --norc -exo pipefail {0} +# jobs: +# tests: +# runs-on: ${{ matrix.os }} +# if: ${{ github.event_name == 'pull_request' || github.event_name == 'push'|| github.event_name == 'workflow_call' }} +# strategy: +# fail-fast: false +# matrix: +# os: [ubuntu-latest] +# python-version: ['3.9', '3.12'] +# name: tests (${{ matrix.python-version }}.${{ matrix.os }}) +# steps: +# - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6 +# with: +# fetch-depth: 0 +# ref: ${{ github.event.pull_request.head.sha }} +# - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1 +# with: +# bentoml-version: 'main' +# python-version: ${{ matrix.python-version }} +# - name: Run tests +# run: hatch run tests:python +# - name: Disambiguate coverage filename +# run: mv .coverage ".coverage.${{ matrix.os }}.${{ matrix.python-version }}" +# - name: Upload coverage data +# uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3 +# with: +# name: coverage-data +# path: .coverage.* +# coverage: +# name: report-coverage +# runs-on: ubuntu-latest +# if: false +# needs: tests +# steps: +# - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4.1.1 +# with: +# fetch-depth: 0 +# ref: ${{ github.event.pull_request.head.sha }} +# - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1 +# with: +# bentoml-version: 'main' +# python-version-file: .python-version-default +# - name: Download coverage data +# uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a # ratchet:actions/download-artifact@v3 +# with: +# name: coverage-data +# - name: Combine coverage data +# run: hatch run coverage:combine +# - name: Export coverage reports +# run: | +# hatch run coverage:report-xml openllm-python +# hatch run coverage:report-uncovered-html openllm-python +# - name: Upload uncovered HTML report +# uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3 +# with: +# name: uncovered-html-report +# path: 
htmlcov +# - name: Generate coverage summary +# run: hatch run coverage:generate-summary +# - name: Write coverage summary report +# if: github.event_name == 'pull_request' +# run: hatch run coverage:write-summary-report +# - name: Update coverage pull request comment +# if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork +# uses: marocchino/sticky-pull-request-comment@331f8f5b4215f0445d3c07b4967662a32a2d3e31 # ratchet:marocchino/sticky-pull-request-comment@v2 +# with: +# path: coverage-report.md +# cli-benchmark: +# name: Check for CLI responsiveness +# runs-on: ubuntu-latest +# env: +# HYPERFINE_VERSION: '1.12.0' +# steps: +# - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4.1.1 +# with: +# fetch-depth: 0 +# - name: Install hyperfine +# run: | +# wget https://github.com/sharkdp/hyperfine/releases/download/v${HYPERFINE_VERSION}/hyperfine_${HYPERFINE_VERSION}_amd64.deb +# sudo dpkg -i hyperfine_${HYPERFINE_VERSION}_amd64.deb +# - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1 +# with: +# bentoml-version: 'main' +# python-version-file: .python-version-default +# - name: Install self +# run: bash local.sh +# - name: Speed +# run: hyperfine -m 100 --warmup 10 openllm +# brew-dry-run: +# name: Running dry-run tests for brew +# runs-on: macos-latest +# steps: +# - name: Install tap and dry-run +# run: | +# brew tap bentoml/openllm https://github.com/bentoml/openllm +# brew install openllm +# openllm --help +# openllm models --show-available +# evergreen: # https://github.com/marketplace/actions/alls-green#why +# if: always() +# needs: +# - tests +# # - cli-benchmark +# # - brew-dry-run +# runs-on: ubuntu-latest +# steps: +# - name: Decide whether the needed jobs succeeded or failed +# uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # ratchet:re-actors/alls-green@release/v1 +# with: +# jobs: ${{ toJSON(needs) }} +# concurrency: +# group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} +# cancel-in-progress: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6134d130..ed48ce24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: autoupdate_commit_msg: 'ci: pre-commit autoupdate [pre-commit.ci]' autofix_prs: false default_language_version: - python: python3.9 # NOTE: sync with .python-version-default + python: python3.11 # NOTE: sync with .python-version-default exclude: '.*\.(css|js|svg)$' repos: - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/.python-version-default b/.python-version-default index bd28b9c5..2c073331 100644 --- a/.python-version-default +++ b/.python-version-default @@ -1 +1 @@ -3.9 +3.11 diff --git a/examples/api_server.py b/examples/api_server.py index 6e1ecb47..e27c5b89 100644 --- a/examples/api_server.py +++ b/examples/api_server.py @@ -1,11 +1,13 @@ from __future__ import annotations -import uuid +import uuid, os from typing import Any, AsyncGenerator, Dict, TypedDict, Union from bentoml import Service from bentoml.io import JSON, Text from openllm import LLM +os.environ['IMPLEMENTATION'] = 'deprecated' + llm = LLM[Any, Any]('HuggingFaceH4/zephyr-7b-alpha', backend='vllm') diff --git a/examples/async_client.py b/examples/async_client.py new file mode 100644 index 00000000..a2550064 --- /dev/null +++ b/examples/async_client.py @@ -0,0 +1,4 @@ +import openllm, asyncio +client = 
openllm.AsyncHTTPClient('http://0.0.0.0:3000') +async def main(): assert (await client.health()); print(await client.generate('Explain superconductor to a 5 year old.')) +if __name__ == "__main__": asyncio.run(main()) diff --git a/examples/openai_chat_completion_client.py b/examples/openai_chat_completion_client.py index dced50ff..e19d4224 100644 --- a/examples/openai_chat_completion_client.py +++ b/examples/openai_chat_completion_client.py @@ -14,7 +14,7 @@ model = models.data[0].id # Chat completion API stream = str(os.getenv('STREAM', False)).upper() in ['TRUE', '1', 'YES', 'Y', 'ON'] completions = client.chat.completions.create(messages=[ - ChatCompletionSystemMessageParam(role='system', content='You will be the writing assistant that assume the ton of Ernest Hemmingway.'), + ChatCompletionSystemMessageParam(role='system', content='You will be the writing assistant that assume the tone of Ernest Hemmingway.'), ChatCompletionUserMessageParam(role='user', content='Write an essay on Nietzsche and absurdism.'), ], model=model, max_tokens=1024, stream=stream) diff --git a/hatch.toml b/hatch.toml index 2b7bfda3..53ada4dd 100644 --- a/hatch.toml +++ b/hatch.toml @@ -1,4 +1,6 @@ [envs.default] +installer = "uv" +type = "virtual" dependencies = [ "openllm-core @ {root:uri}/openllm-core", "openllm-client @ {root:uri}/openllm-client", @@ -29,12 +31,12 @@ setup = [ quality = ["bash ./all.sh", "- pre-commit run --all-files", "- pnpm format"] tool = ["quality", "bash ./clean.sh", 'python ./cz.py'] [envs.tests] +installer = "uv" +type = "virtual" dependencies = [ + "openllm[vllm] @ {root:uri}/openllm-python", "openllm-core @ {root:uri}/openllm-core", "openllm-client @ {root:uri}/openllm-client", - "openllm[chatglm,fine-tune] @ {root:uri}/openllm-python", - # NOTE: interact with docker for container tests. 
- "docker", # NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy "coverage[toml]>=6.5", "filelock>=3.7.1", @@ -53,10 +55,7 @@ skip-install = false template = "tests" [envs.tests.scripts] _run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml -vv" -distributed = "_run_script --reruns 5 --reruns-delay 3 --ignore openllm-python/tests/models -n 3 -r aR {args:openllm-python/tests}" -models = "_run_script -s {args:openllm-python/tests/models}" -python = "_run_script --reruns 5 --reruns-delay 3 --ignore openllm-python/tests/models -r aR {args:openllm-python/tests}" -snapshot-models = "_run_script -s --snapshot-update {args:openllm-python/tests/models}" +python = "_run_script -r aR -x {args:openllm-python/tests}" [envs.tests.overrides] env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT=" [envs.coverage] diff --git a/openllm-client/pyproject.toml b/openllm-client/pyproject.toml index 61e41e45..6e68f738 100644 --- a/openllm-client/pyproject.toml +++ b/openllm-client/pyproject.toml @@ -52,7 +52,7 @@ keywords = [ "PyTorch", "Transformers", ] -dependencies = ["openllm-core", "attrs>=23.2.0", "httpx", "distro", "anyio"] +dependencies = ["openllm-core", "httpx", "distro", "anyio"] license = "Apache-2.0" name = "openllm-client" requires-python = ">=3.8" diff --git a/openllm-client/src/openllm_client/_http.py b/openllm-client/src/openllm_client/_http.py index ed802a67..1142d95f 100644 --- a/openllm-client/src/openllm_client/_http.py +++ b/openllm-client/src/openllm_client/_http.py @@ -1,10 +1,6 @@ from __future__ import annotations -import importlib.metadata -import logging -import os -import typing as t -import attr +import importlib.metadata, logging, os, typing as t, orjson, pydantic from ._schemas import Helpers, Metadata, Response, StreamingResponse from ._shim import MAX_RETRIES, AsyncClient, Client @@ -18,28 +14,31 @@ def _address_converter(addr: str): return addr if '://' in addr else 'http://' + addr -@attr.define(init=False) class HTTPClient(Client): - helpers: Helpers = attr.field(init=False) - _api_version: str = 'v1' - _verify: bool = True - __metadata: Metadata | None = None - __config: dict[str, t.Any] | None = None + _helpers: Helpers = pydantic.PrivateAttr() + _api_version: str = pydantic.PrivateAttr(default='v1') + _verify: bool = pydantic.PrivateAttr(default=True) + __metadata: t.Optional[Metadata] = None + __config: t.Optional[t.Dict[str, t.Any]] = None def __repr__(self): - return f'' + return f'' def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'): if address is None: address = os.getenv('OPENLLM_ENDPOINT') if address is None: raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'") - self._api_version, self._verify = api_version, verify - - self.helpers = Helpers(client=self) super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries) + self._helpers = Helpers(client=self) + self._api_version, self._verify = api_version, verify + + @property + def helpers(self): + return self._helpers + def _build_auth_headers(self) -> t.Dict[str, str]: env = os.getenv('OPENLLM_AUTH_TOKEN') if env is not None: @@ -57,13 +56,13 @@ class HTTPClient(Client): def _metadata(self): if self.__metadata is None: path = f'/{self._api_version}/metadata' - self.__metadata = self._post(path, response_cls=Metadata, json={}, options={'max_retries': self._max_retries}) + self.__metadata = self._post(path, response_cls=Metadata, 
json={}, options={'max_retries': self.max_retries}) return self.__metadata @property def _config(self) -> dict[str, t.Any]: if self.__config is None: - self.__config = self._metadata.configuration + self.__config = orjson.loads(self._metadata.configuration) return self.__config def query(self, prompt, **attrs): @@ -71,7 +70,7 @@ class HTTPClient(Client): def health(self): response = self._get( - '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self.max_retries} ) return response.status_code == 200 @@ -79,7 +78,7 @@ class HTTPClient(Client): self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> Response: if timeout is None: - timeout = self._timeout + timeout = self.timeout if verify is None: verify = self._verify # XXX: need to support this again if llm_config is not None: @@ -91,7 +90,7 @@ class HTTPClient(Client): f'/{self._api_version}/generate', response_cls=Response, json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), - options={'max_retries': self._max_retries}, + options={'max_retries': self.max_retries}, ) def generate_stream( @@ -104,7 +103,7 @@ class HTTPClient(Client): self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> t.Iterator[Response]: if timeout is None: - timeout = self._timeout + timeout = self.timeout if verify is None: verify = self._verify # XXX: need to support this again if llm_config is not None: @@ -115,34 +114,37 @@ class HTTPClient(Client): f'/{self._api_version}/generate_stream', response_cls=Response, json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), - options={'max_retries': self._max_retries}, + options={'max_retries': self.max_retries}, stream=True, ) -@attr.define(init=False) -class AsyncHTTPClient(AsyncClient): - helpers: Helpers = attr.field(init=False) - _api_version: str = 'v1' - _verify: bool = True - __metadata: Metadata | None = None - __config: dict[str, t.Any] | None = None +class AsyncHTTPClient(AsyncClient, pydantic.BaseModel): + _helpers: Helpers = pydantic.PrivateAttr() + _api_version: str = pydantic.PrivateAttr(default='v1') + _verify: bool = pydantic.PrivateAttr(default=True) + __metadata: t.Optional[Metadata] = None + __config: t.Optional[t.Dict[str, t.Any]] = None def __repr__(self): - return f'' + return f'' def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'): if address is None: address = os.getenv('OPENLLM_ENDPOINT') if address is None: raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'") - self._api_version, self._verify = api_version, verify - - # mk messages to be async here - self.helpers = Helpers.permute(messages=Helpers.async_messages)(async_client=self) super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries) + # mk messages to be async here + self._helpers = Helpers.permute(messages=Helpers.async_messages)(async_client=self) + self._api_version, self._verify = api_version, verify + + @property + def helpers(self): + return self._helpers + def _build_auth_headers(self) -> t.Dict[str, str]: env = os.getenv('OPENLLM_AUTH_TOKEN') if env is not None: @@ -153,14 +155,14 @@ class AsyncHTTPClient(AsyncClient): async def _metadata(self) -> t.Awaitable[Metadata]: if self.__metadata is None: self.__metadata = await 
self._post( - f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries} + f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self.max_retries} ) return self.__metadata @property async def _config(self): if self.__config is None: - self.__config = (await self._metadata).configuration + self.__config = orjson.loads((await self._metadata).configuration) return self.__config async def query(self, prompt, **attrs): @@ -168,7 +170,7 @@ class AsyncHTTPClient(AsyncClient): async def health(self): response = await self._get( - '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} + '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self.max_retries} ) return response.status_code == 200 @@ -176,7 +178,7 @@ class AsyncHTTPClient(AsyncClient): self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> Response: if timeout is None: - timeout = self._timeout + timeout = self.timeout if verify is None: verify = self._verify # XXX: need to support this again _metadata = await self._metadata @@ -189,22 +191,22 @@ class AsyncHTTPClient(AsyncClient): f'/{self._api_version}/generate', response_cls=Response, json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), - options={'max_retries': self._max_retries}, + options={'max_retries': self.max_retries}, ) async def generate_stream( - self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs + self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, index=0, **attrs ) -> t.AsyncGenerator[StreamingResponse, t.Any]: async for response_chunk in self.generate_iterator( prompt, llm_config, stop, adapter_name, timeout, verify, **attrs ): - yield StreamingResponse.from_response_chunk(response_chunk) + yield StreamingResponse.from_response_chunk(response_chunk, index=index) async def generate_iterator( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> t.AsyncGenerator[Response, t.Any]: if timeout is None: - timeout = self._timeout + timeout = self.timeout if verify is None: verify = self._verify # XXX: need to support this again _metadata = await self._metadata @@ -218,7 +220,7 @@ class AsyncHTTPClient(AsyncClient): f'/{self._api_version}/generate_stream', response_cls=Response, json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name), - options={'max_retries': self._max_retries}, + options={'max_retries': self.max_retries}, stream=True, ): yield response_chunk diff --git a/openllm-client/src/openllm_client/_schemas.py b/openllm-client/src/openllm_client/_schemas.py index 1723d583..aff47cca 100644 --- a/openllm-client/src/openllm_client/_schemas.py +++ b/openllm-client/src/openllm_client/_schemas.py @@ -1,17 +1,13 @@ from __future__ import annotations -import types -import typing as t -import attr -import orjson +import types, pydantic, typing as t from openllm_core._schemas import ( CompletionChunk as CompletionChunk, - GenerationOutput as Response, # backward compatibility - _SchemaMixin as _SchemaMixin, + MetadataOutput as Metadata, + GenerationOutput as Response, ) - -from ._utils import converter +from ._typing_compat import TypedDict if t.TYPE_CHECKING: from ._shim import AsyncClient, Client @@ -20,65 +16,28 @@ if t.TYPE_CHECKING: __all__ = ['CompletionChunk', 'Helpers', 
'Metadata', 'Response', 'StreamingResponse'] -@attr.define -class Metadata(_SchemaMixin): - """NOTE: Metadata is a modified version of the original MetadataOutput from openllm-core. - - The configuration is now structured into a dictionary for easy of use.""" - - model_id: str - timeout: int - model_name: str - backend: str - configuration: t.Dict[str, t.Any] - - -def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metadata: - try: - configuration = orjson.loads(data['configuration']) - generation_config = configuration.pop('generation_config') - configuration = {**configuration, **generation_config} - except orjson.JSONDecodeError as e: - raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None - try: - return cls( - model_id=data['model_id'], - timeout=data['timeout'], - model_name=data['model_name'], - backend=data['backend'], - configuration=configuration, - ) - except Exception as e: - raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None - - -converter.register_structure_hook(Metadata, _structure_metadata) - - -@attr.define -class StreamingResponse(_SchemaMixin): +class StreamingResponse(pydantic.BaseModel): request_id: str index: int text: str token_ids: int @classmethod - def from_response_chunk(cls, response: Response) -> StreamingResponse: + def from_response_chunk(cls, response: Response, index: int = 0) -> StreamingResponse: return cls( request_id=response.request_id, - index=response.outputs[0].index, - text=response.outputs[0].text, - token_ids=response.outputs[0].token_ids[0], + index=response.outputs[index].index, + text=response.outputs[index].text, + token_ids=response.outputs[index].token_ids[0], ) -class MesssageParam(t.TypedDict): +class MesssageParam(TypedDict): role: t.Literal['user', 'system', 'assistant'] content: str -@attr.define(repr=False) -class Helpers: +class Helpers(pydantic.BaseModel): _client: t.Optional[Client] = None _async_client: t.Optional[AsyncClient] = None diff --git a/openllm-client/src/openllm_client/_shim.py b/openllm-client/src/openllm_client/_shim.py index b5499f5a..02c05298 100644 --- a/openllm-client/src/openllm_client/_shim.py +++ b/openllm-client/src/openllm_client/_shim.py @@ -1,22 +1,11 @@ # This provides a base shim with httpx and acts as base request from __future__ import annotations -import asyncio -import email.utils -import logging -import platform -import random -import time -import typing as t +import asyncio, logging, platform, random, time, typing as t +import email.utils, anyio, distro, httpx, pydantic -import anyio -import attr -import distro -import httpx - -from ._stream import AsyncStream, Response, Stream +from ._stream import AsyncStream, Stream, Response from ._typing_compat import Annotated, Architecture, LiteralString, Platform -from ._utils import converter logger = logging.getLogger(__name__) @@ -92,15 +81,16 @@ def _architecture() -> Architecture: @t.final -@attr.frozen(auto_attribs=True) -class RequestOptions: - method: str = attr.field(converter=str.lower) +class RequestOptions(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='forbid', protected_namespaces=()) + + method: pydantic.constr(to_lower=True) url: str - json: t.Optional[t.Dict[str, t.Any]] = attr.field(default=None) - params: t.Optional[t.Mapping[str, t.Any]] = attr.field(default=None) - headers: t.Optional[t.Dict[str, str]] = attr.field(default=None) - max_retries: int = attr.field(default=MAX_RETRIES) - return_raw_response: bool = attr.field(default=False) + 
json: t.Optional[t.Dict[str, t.Any]] = pydantic.Field(default=None) + params: t.Optional[t.Mapping[str, t.Any]] = pydantic.Field(default=None) + headers: t.Optional[t.Dict[str, str]] = pydantic.Field(default=None) + max_retries: int = pydantic.Field(default=MAX_RETRIES) + return_raw_response: bool = pydantic.Field(default=False) def get_max_retries(self, max_retries: int | None) -> int: return max_retries if max_retries is not None else self.max_retries @@ -111,95 +101,97 @@ class RequestOptions: @t.final -@attr.frozen(auto_attribs=True) -class APIResponse(t.Generic[Response]): - _raw_response: httpx.Response - _client: BaseClient[t.Any, t.Any] - _response_cls: type[Response] - _stream: bool - _stream_cls: t.Optional[t.Union[t.Type[Stream[t.Any]], t.Type[AsyncStream[t.Any]]]] - _options: RequestOptions +class APIResponse(pydantic.BaseModel, t.Generic[Response]): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) - _parsed: t.Optional[Response] = attr.field(default=None, repr=False) + raw_response: httpx.Response + client: t.Union[AsyncClient, Client] + response_cls: t.Optional[type[Response]] + stream: bool + stream_cls: t.Optional[t.Union[t.Type[Stream[t.Any]], t.Type[AsyncStream[t.Any]]]] + options: RequestOptions + _parsed: t.Optional[Response] = pydantic.PrivateAttr(default=None) def parse(self): - if self._options.return_raw_response: - return self._raw_response + if self.options.return_raw_response: + return self.raw_response + + if self.response_cls is None: + raise ValueError('Response class cannot be None.') + if self._parsed is not None: return self._parsed - if self._stream: - stream_cls = self._stream_cls or self._client._default_stream_cls - return stream_cls(response_cls=self._response_cls, response=self._raw_response, client=self._client) - if self._response_cls is str: - return self._raw_response.text + if self.stream: + stream_cls = self.stream_cls or self.client.default_stream_cls + return stream_cls(response_cls=self.response_cls, response=self.raw_response, client=self.client) - content_type, *_ = self._raw_response.headers.get('content-type', '').split(';') + if self.response_cls is str: + return self.raw_response.text + + content_type, *_ = self.raw_response.headers.get('content-type', '').split(';') if content_type != 'application/json': # Since users specific different content_type, then we return the raw binary text without and deserialisation - return self._raw_response.text + return self.raw_response.text - data = self._raw_response.json() + data = self.raw_response.json() try: - return self._client._process_response_data( - data=data, response_cls=self._response_cls, raw_response=self._raw_response + return self.client._process_response_data( + data=data, response_cls=self.response_cls, raw_response=self.raw_response ) except Exception as exc: raise ValueError(exc) from None # validation error here @property def headers(self): - return self._raw_response.headers + return self.raw_response.headers @property def status_code(self): - return self._raw_response.status_code + return self.raw_response.status_code @property def request(self): - return self._raw_response.request + return self.raw_response.request @property def url(self): - return self._raw_response.url + return self.raw_response.url @property def content(self): - return self._raw_response.content + return self.raw_response.content @property def text(self): - return self._raw_response.text + return self.raw_response.text @property def http_version(self): - return 
self._raw_response.http_version + return self.raw_response.http_version -@attr.define(init=False) -class BaseClient(t.Generic[InnerClient, StreamType]): - _base_url: httpx.URL = attr.field(converter=_address_converter) - _version: LiteralVersion - _timeout: httpx.Timeout = attr.field(converter=httpx.Timeout) - _max_retries: int - _inner: InnerClient - _default_stream_cls: t.Type[StreamType] - _auth_headers: t.Dict[str, str] = attr.field(init=False) +class BaseClient(pydantic.BaseModel, t.Generic[InnerClient, StreamType]): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) - def __init__( - self, - *, - base_url: str | httpx.URL, - version: str, - timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, - max_retries: int = MAX_RETRIES, - client: InnerClient, - _default_stream_cls: t.Type[StreamType], - _internal: bool = False, - ): - if not _internal: - raise RuntimeError('Client is reserved to be used internally only.') - self.__attrs_init__(base_url, version, timeout, max_retries, client, _default_stream_cls) + _base_url: httpx.URL = pydantic.PrivateAttr() + version: pydantic.SkipValidation[LiteralVersion] + timeout: httpx.Timeout + max_retries: int + inner: InnerClient + default_stream_cls: pydantic.SkipValidation[t.Type[StreamType]] + _auth_headers: t.Dict[str, str] = pydantic.PrivateAttr() + + @pydantic.field_validator('timeout', mode='before') + @classmethod + def convert_timeout(cls, value: t.Any) -> httpx.Timeout: + return httpx.Timeout(value) + + def __init__(self, base_url: str, **data: t.Any): + super().__init__(**data) + self._base_url = _address_converter(base_url) + + def model_post_init(self, *_: t.Any): self._auth_headers = self._build_auth_headers() def _prepare_url(self, url: str) -> httpx.URL: @@ -212,7 +204,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]): @property def is_closed(self): - return self._inner.is_closed + return self.inner.is_closed @property def is_ready(self): @@ -239,7 +231,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]): @property def user_agent(self): - return f'{self.__class__.__name__}/Python {self._version}' + return f'{self.__class__.__name__}/Python {self.version}' @property def auth_headers(self): @@ -258,7 +250,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]): @property def platform_headers(self): return { - 'X-OpenLLM-Client-Package-Version': self._version, + 'X-OpenLLM-Client-Package-Version': self.version, 'X-OpenLLM-Client-Language': 'Python', 'X-OpenLLM-Client-Runtime': platform.python_implementation(), 'X-OpenLLM-Client-Runtime-Version': platform.python_version(), @@ -267,13 +259,13 @@ class BaseClient(t.Generic[InnerClient, StreamType]): } def _remaining_retries(self, remaining_retries: int | None, options: RequestOptions) -> int: - return remaining_retries if remaining_retries is not None else options.get_max_retries(self._max_retries) + return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries) def _build_headers(self, options: RequestOptions) -> httpx.Headers: return httpx.Headers(_merge_mapping(self._default_headers, options.headers or {})) def _build_request(self, options: RequestOptions) -> httpx.Request: - return self._inner.build_request( + return self.inner.build_request( method=options.method, headers=self._build_headers(options), url=self._prepare_url(options.url), @@ -284,7 +276,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]): def _calculate_retry_timeout( self, remaining_retries: int, options: RequestOptions, headers: 
t.Optional[httpx.Headers] = None ) -> float: - max_retries = options.get_max_retries(self._max_retries) + max_retries = options.get_max_retries(self.max_retries) # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After try: if headers is not None: @@ -327,7 +319,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]): def _process_response_data( self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response ) -> Response: - return converter.structure(data, response_cls) + return response_cls(**data) def _process_response( self, @@ -348,7 +340,6 @@ class BaseClient(t.Generic[InnerClient, StreamType]): ).parse() -@attr.define(init=False) class Client(BaseClient[httpx.Client, Stream[t.Any]]): def __init__( self, @@ -362,18 +353,17 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): version=version, timeout=timeout, max_retries=max_retries, - client=httpx.Client(base_url=base_url, timeout=timeout), - _default_stream_cls=Stream, - _internal=True, + inner=httpx.Client(base_url=base_url, timeout=timeout), + default_stream_cls=Stream, ) def close(self): - self._inner.close() + self.inner.close() def __enter__(self): return self - def __exit__(self, *args) -> None: + def __exit__(self, *args: t.Any) -> None: self.close() def __del__(self): @@ -408,7 +398,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): retries = self._remaining_retries(remaining_retries, options) request = self._build_request(options) try: - response = self._inner.send(request, auth=self.auth, stream=stream) + response = self.inner.send(request, auth=self.auth, stream=stream) logger.debug('HTTP [%s, %s]: %i [%s]', request.method, request.url, response.status_code, response.reason_phrase) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -418,7 +408,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): ) # If the response is streamed then we need to explicitly read the completed response exc.response.read() - raise ValueError(exc.message) from None + raise ValueError(exc) from None except httpx.TimeoutException: if retries > 0: return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) @@ -481,7 +471,6 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): ) -@attr.define(init=False) class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): def __init__( self, @@ -495,18 +484,17 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): version=version, timeout=timeout, max_retries=max_retries, - client=httpx.AsyncClient(base_url=base_url, timeout=timeout), - _default_stream_cls=AsyncStream, - _internal=True, + inner=httpx.AsyncClient(base_url=base_url, timeout=timeout), + default_stream_cls=AsyncStream, ) async def close(self): - await self._inner.aclose() + await self.inner.aclose() async def __aenter__(self): return self - async def __aexit__(self, *args) -> None: + async def __aexit__(self, *args: t.Any) -> None: await self.close() def __del__(self): @@ -545,7 +533,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): request = self._build_request(options) try: - response = await self._inner.send(request, auth=self.auth, stream=stream) + response = await self.inner.send(request, auth=self.auth, stream=stream) logger.debug('HTTP [%s, %s]: %i [%s]', request.method, request.url, response.status_code, response.reason_phrase) response.raise_for_status() except httpx.HTTPStatusError as exc: @@ -555,7 +543,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, 
AsyncStream[t.Any]]): ) # If the response is streamed then we need to explicitly read the completed response await exc.response.aread() - raise ValueError(exc.message) from None + raise ValueError(exc) from None except httpx.ConnectTimeout as err: if retries > 0: return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) @@ -622,3 +610,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): return await self.request( response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls ) + + +Stream.model_rebuild() +AsyncStream.model_rebuild() diff --git a/openllm-client/src/openllm_client/_stream.py b/openllm-client/src/openllm_client/_stream.py index e81a7fb0..c40923ad 100644 --- a/openllm-client/src/openllm_client/_stream.py +++ b/openllm-client/src/openllm_client/_stream.py @@ -1,26 +1,23 @@ from __future__ import annotations -import typing as t -import attr -import httpx -import orjson +import pydantic, httpx, orjson, typing as t if t.TYPE_CHECKING: from ._shim import AsyncClient, Client -Response = t.TypeVar('Response', bound=attr.AttrsInstance) +Response = t.TypeVar('Response', bound=pydantic.BaseModel) -@attr.define(auto_attribs=True) -class Stream(t.Generic[Response]): - _response_cls: t.Type[Response] - _response: httpx.Response - _client: Client - _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder()) - _iterator: t.Iterator[Response] = attr.field(init=False) +class Stream(pydantic.BaseModel, t.Generic[Response]): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + response_cls: t.Type[Response] + response: pydantic.SkipValidation[httpx.Response] + client: Client + _decoder: SSEDecoder = pydantic.PrivateAttr(default_factory=lambda: SSEDecoder()) + _iterator: t.Iterator[Response] = pydantic.PrivateAttr() - def __init__(self, response_cls, response, client): - self.__attrs_init__(response_cls, response, client) + def __init__(self, **data): + super().__init__(**data) self._iterator = self._stream() def __next__(self): @@ -31,28 +28,28 @@ class Stream(t.Generic[Response]): yield item def _iter_events(self) -> t.Iterator[SSE]: - yield from self._decoder.iter(self._response.iter_lines()) + yield from self._decoder.iter(self.response.iter_lines()) def _stream(self) -> t.Iterator[Response]: for sse in self._iter_events(): if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data( - data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + yield self.client._process_response_data( + data=orjson.loads(sse.data), response_cls=self.response_cls, raw_response=self.response ) -@attr.define(auto_attribs=True) -class AsyncStream(t.Generic[Response]): - _response_cls: t.Type[Response] - _response: httpx.Response - _client: AsyncClient - _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder()) - _iterator: t.Iterator[Response] = attr.field(init=False) +class AsyncStream(pydantic.BaseModel, t.Generic[Response]): + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + response_cls: t.Type[Response] + response: pydantic.SkipValidation[httpx.Response] + client: AsyncClient + _decoder: SSEDecoder = pydantic.PrivateAttr(default_factory=lambda: SSEDecoder()) + _iterator: t.Iterator[Response] = pydantic.PrivateAttr() - def __init__(self, response_cls, response, client): - self.__attrs_init__(response_cls, response, client) + def __init__(self, **data): + super().__init__(**data) 
self._iterator = self._stream() async def __anext__(self): @@ -63,7 +60,7 @@ class AsyncStream(t.Generic[Response]): yield item async def _iter_events(self): - async for sse in self._decoder.aiter(self._response.aiter_lines()): + async for sse in self._decoder.aiter(self.response.aiter_lines()): yield sse async def _stream(self) -> t.AsyncGenerator[Response, None]: @@ -71,31 +68,23 @@ class AsyncStream(t.Generic[Response]): if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data( - data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response + yield self.client._process_response_data( + data=orjson.loads(sse.data), response_cls=self.response_cls, raw_response=self.response ) -@attr.define -class SSE: - data: str = attr.field(default='') - id: t.Optional[str] = attr.field(default=None) - event: t.Optional[str] = attr.field(default=None) - retry: t.Optional[int] = attr.field(default=None) - - def model_dump(self) -> t.Dict[str, t.Any]: - try: - return orjson.loads(self.data) - except orjson.JSONDecodeError: - raise +class SSE(pydantic.BaseModel): + data: str = pydantic.Field(default='') + id: t.Optional[str] = pydantic.Field(default=None) + event: t.Optional[str] = pydantic.Field(default=None) + retry: t.Optional[int] = pydantic.Field(default=None) -@attr.define(auto_attribs=True) -class SSEDecoder: - _data: t.List[str] = attr.field(factory=list) - _event: t.Optional[str] = None - _retry: t.Optional[int] = None - _last_event_id: t.Optional[str] = None +class SSEDecoder(pydantic.BaseModel): + _data: t.List[str] = pydantic.PrivateAttr(default_factory=list) + event: t.Optional[str] = None + retry: t.Optional[int] = None + last_event_id: t.Optional[str] = None def iter(self, iterator: t.Iterator[str]) -> t.Iterator[SSE]: for line in iterator: @@ -112,10 +101,10 @@ class SSEDecoder: def decode(self, line: str) -> SSE | None: # NOTE: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation if not line: - if all(not a for a in [self._event, self._data, self._retry, self._last_event_id]): + if all(not a for a in [self.event, self._data, self.retry, self.last_event_id]): return None - sse = SSE(data='\n'.join(self._data), event=self._event, retry=self._retry, id=self._last_event_id) - self._event, self._data, self._retry = None, [], None + sse = SSE(data='\n'.join(self._data), event=self.event, retry=self.retry, id=self.last_event_id) + self.event, self._data, self.retry = None, [], None return sse if line.startswith(':'): return None @@ -123,17 +112,17 @@ class SSEDecoder: if value.startswith(' '): value = value[1:] if field == 'event': - self._event = value + self.event = value elif field == 'data': self._data.append(value) elif field == 'id': if '\0' in value: pass else: - self._last_event_id = value + self.last_event_id = value elif field == 'retry': try: - self._retry = int(value) + self.retry = int(value) except (TypeError, ValueError): pass else: diff --git a/openllm-client/src/openllm_client/_typing_compat.py b/openllm-client/src/openllm_client/_typing_compat.py index 15d86f8a..48f3c37d 100644 --- a/openllm-client/src/openllm_client/_typing_compat.py +++ b/openllm-client/src/openllm_client/_typing_compat.py @@ -7,6 +7,7 @@ from openllm_core._typing_compat import ( Required as Required, Self as Self, TypeGuard as TypeGuard, + TypedDict as TypedDict, dataclass_transform as dataclass_transform, overload as overload, ) diff --git a/openllm-python/README.md b/openllm-python/README.md index 
dc3117f2..e2b38c0e 100644 --- a/openllm-python/README.md +++ b/openllm-python/README.md @@ -27,7 +27,7 @@ OpenLLM helps developers **run any open-source LLMs**, such as Llama 2 and Mistr - 🚂 Support a wide range of open-source LLMs including LLMs fine-tuned with your own data - ⛓️ OpenAI compatible API endpoints for seamless transition from your LLM app to open-source LLMs - 🔥 State-of-the-art serving and inference performance -- 🎯 Simplified cloud deployment via [BentoML](www.bentoml.com) +- 🎯 Simplified cloud deployment via [BentoML](https://www.bentoml.com) diff --git a/openllm-python/pyproject.toml b/openllm-python/pyproject.toml index 5051d5e2..ea972b79 100644 --- a/openllm-python/pyproject.toml +++ b/openllm-python/pyproject.toml @@ -150,45 +150,6 @@ only-include = ["src/openllm", "src/openllm_cli", "src/_openllm_tiny"] sources = ["src"] [tool.hatch.build.targets.sdist] exclude = ["/.git_archival.txt", "tests", "/.python-version-default"] -[tool.hatch.build.targets.wheel.hooks.mypyc] -dependencies = [ - "hatch-mypyc==0.16.0", - "mypy==1.7.0", - # avoid https://github.com/pallets/click/issues/2558 - "click==8.1.3", - "bentoml==1.1.9", - "transformers>=4.32.1", - "pandas-stubs", - "types-psutil", - "types-tabulate", - "types-PyYAML", - "types-protobuf", -] -enable-by-default = false -exclude = ["src/_openllm_tiny/_service.py", "src/openllm/utils/__init__.py"] -include = [ - "src/openllm/__init__.py", - "src/openllm/_quantisation.py", - "src/openllm/_generation.py", - "src/openllm/exceptions.py", - "src/openllm/testing.py", - "src/openllm/utils", -] -# NOTE: This is consistent with pyproject.toml -mypy-args = [ - "--strict", - # this is because all transient library doesn't have types - "--follow-imports=skip", - "--allow-subclassing-any", - "--check-untyped-defs", - "--ignore-missing-imports", - "--no-warn-return-any", - "--warn-unreachable", - "--no-warn-no-return", - "--no-warn-unused-ignores", -] -options = { verbose = true, strip_asserts = true, debug_level = "2", opt_level = "3", include_runtime_files = true } -require-runtime-dependencies = true [tool.hatch.metadata.hooks.fancy-pypi-readme] content-type = "text/markdown" # PyPI doesn't support the tag. 
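For reference, a minimal sketch of how the reworked pydantic-based clients in openllm-client above are meant to be driven, mirroring examples/async_client.py but for the synchronous HTTPClient. This is illustrative only, not part of the patch: it assumes HTTPClient is re-exported from openllm the same way AsyncHTTPClient is, and the address is a placeholder.

import openllm

# Placeholder address; per the patch, OPENLLM_ENDPOINT can be used instead of passing it explicitly.
client = openllm.HTTPClient('http://localhost:3000', timeout=30)

if client.health():  # GET /readyz; True when the server returns HTTP 200
    res = client.generate('Explain superconductor to a 5 year old.')
    print(res.outputs[0].text)  # Response is openllm_core's GenerationOutput

    # generate_stream yields StreamingResponse chunks carrying a .text field
    for chunk in client.generate_stream('Write a haiku about the ocean.'):
        print(chunk.text, end='', flush=True)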
diff --git a/openllm-python/src/openllm/__main__.py b/openllm-python/src/openllm/__main__.py index 72092181..3634a55c 100644 --- a/openllm-python/src/openllm/__main__.py +++ b/openllm-python/src/openllm/__main__.py @@ -1,4 +1,4 @@ if __name__ == '__main__': - from openllm_cli.entrypoint import cli + from _openllm_tiny._entrypoint import cli cli() diff --git a/openllm-python/src/openllm/entrypoints/openai.py b/openllm-python/src/openllm/entrypoints/openai.py index 80db02b9..54c5ddc1 100644 --- a/openllm-python/src/openllm/entrypoints/openai.py +++ b/openllm-python/src/openllm/entrypoints/openai.py @@ -70,7 +70,7 @@ def error_response(status_code, message): ) -async def check_model(request, model): +async def check_model(request, model): # noqa if request.model == model: return None return error_response( diff --git a/openllm-python/tests/__init__.py b/openllm-python/tests/__init__.py index d423c23a..e69de29b 100644 --- a/openllm-python/tests/__init__.py +++ b/openllm-python/tests/__init__.py @@ -1,9 +0,0 @@ -from __future__ import annotations -import os - -from hypothesis import HealthCheck, settings - -settings.register_profile('CI', settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None) - -if 'CI' in os.environ: - settings.load_profile('CI') diff --git a/openllm-python/tests/_data.py b/openllm-python/tests/_data.py new file mode 100644 index 00000000..a8e93a0f --- /dev/null +++ b/openllm-python/tests/_data.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import typing as t +from openllm_core._typing_compat import TypedDict +from datasets import load_dataset + +if t.TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +FIXED_OUTPUT_LENGTH = 128 + + +class DatasetEntry(TypedDict): + human: str + gpt: str + + +class SampledRequest(TypedDict): + prompt: str + prompt_length: int + output_length: int + + +def prepare_sharegpt_request( + num_requests: int, tokenizer: PreTrainedTokenizerBase, max_output_length: int | None = None +) -> list[SampledRequest]: + def transform(examples) -> DatasetEntry: + human, gpt = [], [] + for example in examples['conversations']: + human.append(example[0]['value']) + gpt.append(example[1]['value']) + return {'human': human, 'gpt': gpt} + + def process(examples, tokenizer, max_output_length: t.Optional[int]): + # Tokenize the 'human' and 'gpt' values in batches + prompt_token_ids = tokenizer(examples['human']).input_ids + completion_token_ids = tokenizer(examples['gpt']).input_ids + + # Create the transformed entries + return { + 'prompt': examples['human'], + 'prompt_length': [len(ids) for ids in prompt_token_ids], + 'output_length': [ + len(ids) if max_output_length is None else FIXED_OUTPUT_LENGTH for ids in completion_token_ids + ], + } + + def filter_length(examples) -> list[bool]: + result = [] + for prompt_length, output_length in zip(examples['prompt_length'], examples['output_length']): + if prompt_length < 4 or output_length < 4: + result.append(False) + elif prompt_length > 1024 or prompt_length + output_length > 2048: + result.append(False) + else: + result.append(True) + return result + + return ( + ( + dataset := load_dataset( + 'anon8231489123/ShareGPT_Vicuna_unfiltered', + data_files='ShareGPT_V3_unfiltered_cleaned_split.json', + split='train', + ) + ) + .filter(lambda example: len(example['conversations']) >= 2, num_proc=8) + .map(transform, remove_columns=dataset.column_names, batched=True) + .map( + process, + fn_kwargs={'tokenizer': tokenizer, 'max_output_length': max_output_length}, + 
remove_columns=['human', 'gpt'], + batched=True, + ) + .filter(filter_length, batched=True) + .shuffle(seed=42) + .to_list()[:num_requests] + ) diff --git a/openllm-python/tests/_strategies/__init__.py b/openllm-python/tests/_strategies/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/openllm-python/tests/_strategies/_configuration.py b/openllm-python/tests/_strategies/_configuration.py deleted file mode 100644 index 25e12d47..00000000 --- a/openllm-python/tests/_strategies/_configuration.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations -import logging -import typing as t - -from hypothesis import strategies as st - -import openllm -from openllm_core._configuration import ModelSettings - -logger = logging.getLogger(__name__) - - -@st.composite -def model_settings(draw: st.DrawFn): - """Strategy for generating ModelSettings objects.""" - kwargs: dict[str, t.Any] = { - 'default_id': st.text(min_size=1), - 'model_ids': st.lists(st.text(), min_size=1), - 'architecture': st.text(min_size=1), - 'url': st.text(), - 'trust_remote_code': st.booleans(), - 'requirements': st.none() | st.lists(st.text(), min_size=1), - 'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']), - 'name_type': st.sampled_from(['dasherize', 'lowercase']), - 'timeout': st.integers(min_value=3600), - 'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)), - } - return draw(st.builds(ModelSettings, **kwargs)) - - -def make_llm_config( - cls_name: str, - dunder_config: dict[str, t.Any] | ModelSettings, - fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None, - generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None, -) -> type[openllm.LLMConfig]: - globs: dict[str, t.Any] = {'openllm': openllm} - _config_args: list[str] = [] - lines: list[str] = [f'class {cls_name}Config(openllm.LLMConfig):'] - for attr, value in dunder_config.items(): - _config_args.append(f'"{attr}": __attr_{attr}') - globs[f'_{cls_name}Config__attr_{attr}'] = value - lines.append(f' __config__ = {{ {", ".join(_config_args)} }}') - if fields is not None: - for field, type_, default in fields: - lines.append(f' {field}: {type_} = openllm.LLMConfig.Field({default!r})') - if generation_fields is not None: - generation_lines = ['class GenerationConfig:'] - for field, default in generation_fields: - generation_lines.append(f' {field} = {default!r}') - lines.extend((' ' + line for line in generation_lines)) - - script = '\n'.join(lines) - - if openllm.utils.DEBUG: - logger.info('Generated class %s:\n%s', cls_name, script) - - eval(compile(script, 'name', 'exec'), globs) - - return globs[f'{cls_name}Config'] diff --git a/openllm-python/tests/configuration_test.py b/openllm-python/tests/configuration_test.py deleted file mode 100644 index fafa4098..00000000 --- a/openllm-python/tests/configuration_test.py +++ /dev/null @@ -1,152 +0,0 @@ -from __future__ import annotations -import contextlib -import os -import typing as t -from unittest import mock - -import attr -import pytest -from hypothesis import assume, given, strategies as st - -import openllm -from openllm_core._configuration import GenerationConfig, ModelSettings, field_env_key - -from ._strategies._configuration import make_llm_config, model_settings - - -def test_forbidden_access(): - cl_ = make_llm_config( - 'ForbiddenAccess', - { - 'default_id': 'huggingface/t5-tiny-testing', - 'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'], - 'architecture': 
'PreTrainedModel', - 'requirements': ['bentoml'], - }, - ) - - assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), '__config__') - assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'GenerationConfig') - assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'SamplingParams') - assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig) - - -@given(model_settings()) -def test_class_normal_gen(gen_settings: ModelSettings): - assume(gen_settings['default_id'] and all(i for i in gen_settings['model_ids'])) - cl_: type[openllm.LLMConfig] = make_llm_config('NotFullLLM', gen_settings) - assert issubclass(cl_, openllm.LLMConfig) - for key in gen_settings: - assert object.__getattribute__(cl_, f'__openllm_{key}__') == gen_settings.__getitem__(key) - - -@given(model_settings(), st.integers()) -def test_simple_struct_dump(gen_settings: ModelSettings, field1: int): - cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),)) - assert cl_().model_dump()['field1'] == field1 - - -@given(model_settings(), st.integers()) -def test_config_derivation(gen_settings: ModelSettings, field1: int): - cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),)) - new_cls = cl_.model_derivate('DerivedLLM', default_id='asdfasdf') - assert new_cls.__openllm_default_id__ == 'asdfasdf' - - -@given(model_settings()) -def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings): - cl_ = make_llm_config('AttrsProtocolLLM', gen_settings) - assert attr.has(cl_) - - -@given( - model_settings(), - st.integers(max_value=283473), - st.floats(min_value=0.0, max_value=1.0), - st.integers(max_value=283473), - st.floats(min_value=0.0, max_value=1.0), -) -def test_complex_struct_dump( - gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float -): - cl_ = make_llm_config( - 'ComplexLLM', - gen_settings, - fields=(('field1', 'float', field1),), - generation_fields=(('temperature', temperature),), - ) - sent = cl_() - assert sent.model_dump()['field1'] == field1 - assert sent.model_dump()['generation_config']['temperature'] == temperature - assert sent.model_dump(flatten=True)['field1'] == field1 - assert sent.model_dump(flatten=True)['temperature'] == temperature - - passed = cl_(field1=input_field1, temperature=input_temperature) - assert passed.model_dump()['field1'] == input_field1 - assert passed.model_dump()['generation_config']['temperature'] == input_temperature - assert passed.model_dump(flatten=True)['field1'] == input_field1 - assert passed.model_dump(flatten=True)['temperature'] == input_temperature - - pas_nested = cl_(generation_config={'temperature': input_temperature}, field1=input_field1) - assert pas_nested.model_dump()['field1'] == input_field1 - assert pas_nested.model_dump()['generation_config']['temperature'] == input_temperature - - -@contextlib.contextmanager -def patch_env(**attrs: t.Any): - with mock.patch.dict(os.environ, attrs, clear=True): - yield - - -def test_struct_envvar(): - with patch_env(**{field_env_key('field1'): '4', field_env_key('temperature', suffix='generation'): '0.2'}): - - class EnvLLM(openllm.LLMConfig): - __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'} - field1: int = 2 - - class GenerationConfig: - temperature: float = 0.8 - - sent = 
EnvLLM.model_construct_env() - assert sent.field1 == 4 - assert sent['temperature'] == 0.2 - - overwrite_default = EnvLLM() - assert overwrite_default.field1 == 4 - assert overwrite_default['temperature'] == 0.2 - - -def test_struct_provided_fields(): - class EnvLLM(openllm.LLMConfig): - __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'} - field1: int = 2 - - class GenerationConfig: - temperature: float = 0.8 - - sent = EnvLLM.model_construct_env(field1=20, temperature=0.4) - assert sent.field1 == 20 - assert sent.generation_config.temperature == 0.4 - - -def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mk: - mk.setenv(field_env_key('field1'), str(4.0)) - mk.setenv(field_env_key('temperature', suffix='generation'), str(0.2)) - sent = make_llm_config( - 'OverwriteWithEnvAvailable', - {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}, - fields=(('field1', 'float', 3.0),), - ).model_construct_env(field1=20.0, temperature=0.4) - assert sent.generation_config.temperature == 0.4 - assert sent.field1 == 20.0 - - -@pytest.mark.parametrize('model_name', openllm.CONFIG_MAPPING.keys()) -def test_configuration_dict_protocol(model_name: str): - config = openllm.AutoConfig.for_model(model_name) - assert isinstance(config.items(), list) - assert isinstance(config.keys(), list) - assert isinstance(config.values(), list) - assert isinstance(dict(config), dict) diff --git a/openllm-python/tests/conftest.py b/openllm-python/tests/conftest.py index e49b2656..aaa3f895 100644 --- a/openllm-python/tests/conftest.py +++ b/openllm-python/tests/conftest.py @@ -1,42 +1,16 @@ from __future__ import annotations -import itertools -import os -import typing as t -import pytest - -import openllm - -if t.TYPE_CHECKING: - from openllm_core._typing_compat import LiteralBackend - -_MODELING_MAPPING = { - 'flan_t5': 'google/flan-t5-small', - 'opt': 'facebook/opt-125m', - 'baichuan': 'baichuan-inc/Baichuan-7B', -} -_PROMPT_MAPPING = { - 'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?' -} +import pytest, typing as t -def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.Any, t.Any]], None, None]: - if model not in _MODELING_MAPPING: - pytest.skip(f"'{model}' is not yet supported in framework testing.") - backends: tuple[LiteralBackend, ...] = ('pt',) - for backend, prompt in itertools.product(backends, _PROMPT_MAPPING.keys()): - yield prompt, openllm.LLM(_MODELING_MAPPING[model], backend=backend) - - -def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: - if os.getenv('GITHUB_ACTIONS') is None: - if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames: - metafunc.parametrize( - 'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])] - ) - - -def pytest_sessionfinish(session: pytest.Session, exitstatus: int): - # If no tests are collected, pytest exists with code 5, which makes the CI fail. 
- if exitstatus == 5: - session.exitstatus = 0 +@pytest.fixture( + scope='function', + name='model_id', + params={ + 'meta-llama/Meta-Llama-3-8B-Instruct', + 'casperhansen/llama-3-70b-instruct-awq', + 'TheBloke/Nous-Hermes-2-Mixtral-8x7B-DPO-AWQ', + }, +) +def fixture_model_id(request) -> t.Generator[str, None, None]: + yield request.param diff --git a/openllm-python/tests/models_test.py b/openllm-python/tests/models_test.py deleted file mode 100644 index fb983d7c..00000000 --- a/openllm-python/tests/models_test.py +++ /dev/null @@ -1,29 +0,0 @@ -from __future__ import annotations -import os -import typing as t - -import pytest - -if t.TYPE_CHECKING: - import openllm - - -@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI') -def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm.generate(prompt) - - assert llm.generate(prompt, temperature=0.8, top_p=0.23) - - -@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI') -def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm.generate(prompt) - - assert llm.generate(prompt, temperature=0.9, top_k=8) - - -@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI') -def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]): - assert llm.generate(prompt) - - assert llm.generate(prompt, temperature=0.95) diff --git a/openllm-python/tests/package_test.py b/openllm-python/tests/package_test.py deleted file mode 100644 index 793f221a..00000000 --- a/openllm-python/tests/package_test.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations -import functools -import os -import typing as t - -import pytest - -import openllm -from bentoml._internal.configuration.containers import BentoMLContainer - -if t.TYPE_CHECKING: - from pathlib import Path - -HF_INTERNAL_T5_TESTING = 'hf-internal-testing/tiny-random-t5' - -actions_xfail = functools.partial( - pytest.mark.xfail, - condition=os.getenv('GITHUB_ACTIONS') is not None, - reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.', -) - - -@actions_xfail -def test_general_build_with_internal_testing(): - bento_store = BentoMLContainer.bento_store.get() - - llm = openllm.LLM(model_id=HF_INTERNAL_T5_TESTING, serialisation='legacy') - bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING) - - assert llm.llm_type == bento.info.labels['_type'] - assert llm.__llm_backend__ == bento.info.labels['_framework'] - - bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING) - assert len(bento_store.list(bento.tag)) == 1 - - -@actions_xfail -def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory): - local_path = tmp_path_factory.mktemp('local_t5') - llm = openllm.LLM(model_id=HF_INTERNAL_T5_TESTING, serialisation='legacy') - - llm.model.save_pretrained(str(local_path)) - llm._tokenizer.save_pretrained(str(local_path)) - - assert openllm.build('flan-t5', model_id=local_path.resolve().__fspath__(), model_version='local') - - -@pytest.fixture() -def dockerfile_template(tmp_path_factory: pytest.TempPathFactory): - file = tmp_path_factory.mktemp('dockerfiles') / 'Dockerfile.template' - file.write_text( - "{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}" - ) - return file - - -@pytest.mark.usefixtures('dockerfile_template') -@actions_xfail 
-def test_build_with_custom_dockerfile(dockerfile_template: Path): - assert openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template)) diff --git a/openllm-python/tests/regression_test.py b/openllm-python/tests/regression_test.py new file mode 100644 index 00000000..cd37c77a --- /dev/null +++ b/openllm-python/tests/regression_test.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import pytest, subprocess, sys, openllm, bentoml, asyncio +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam + +SERVER_PORT = 53822 + + +@pytest.mark.asyncio +async def test_openai_compatible(model_id: str): + server = subprocess.Popen([sys.executable, '-m', 'openllm', 'start', model_id, '--port', str(SERVER_PORT)]) + await asyncio.sleep(5) + with bentoml.SyncHTTPClient(f'http://127.0.0.1:{SERVER_PORT}', server_ready_timeout=90) as client: + assert client.is_ready(30) + + try: + client = AsyncOpenAI(api_key='na', base_url=f'http://127.0.0.1:{SERVER_PORT}/v1') + serve_model = (await client.models.list()).data[0].id + assert serve_model == openllm.utils.normalise_model_name(model_id) + streamable = await client.chat.completions.create( + model=serve_model, + max_tokens=512, + stream=False, + messages=[ + ChatCompletionSystemMessageParam( + role='system', content='You will be the writing assistant that assume the tone of Ernest Hemmingway.' + ), + ChatCompletionUserMessageParam( + role='user', content='Comment on why Camus thinks we should revolt against life absurdity.' + ), + ], + ) + assert streamable is not None + finally: + server.terminate() + + +@pytest.mark.asyncio +async def test_generate_endpoint(model_id: str): + server = subprocess.Popen([sys.executable, '-m', 'openllm', 'start', model_id, '--port', str(SERVER_PORT)]) + await asyncio.sleep(5) + + with bentoml.SyncHTTPClient(f'http://127.0.0.1:{SERVER_PORT}', server_ready_timeout=90) as client: + assert client.is_ready(30) + + try: + client = openllm.AsyncHTTPClient(f'http://127.0.0.1:{SERVER_PORT}', api_version='v1') + assert await client.health() + + response = await client.generate( + 'Tell me more about Apple as a company', stop='technology', llm_config={'temperature': 0.5, 'top_p': 0.2} + ) + assert response is not None + finally: + server.terminate() diff --git a/openllm-python/tests/strategies_test.py b/openllm-python/tests/strategies_test.py deleted file mode 100644 index 6b95ac0d..00000000 --- a/openllm-python/tests/strategies_test.py +++ /dev/null @@ -1,185 +0,0 @@ -from __future__ import annotations -import os -import typing as t - -import pytest - -import bentoml -from openllm import _strategies as strategy -from openllm._strategies import CascadingResourceStrategy, NvidiaGpuResource, get_resource - -if t.TYPE_CHECKING: - from _pytest.monkeypatch import MonkeyPatch - - -def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv('CUDA_VISIBLE_DEVICES', '0,1') - resource = NvidiaGpuResource.from_system() - assert len(resource) == 2 - assert resource == ['0', '1'] - mcls.delenv('CUDA_VISIBLE_DEVICES') - - -def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv('CUDA_VISIBLE_DEVICES', '0,2,-1,1') - resource = NvidiaGpuResource.from_system() - assert len(resource) == 2 - assert resource == ['0', '2'] - mcls.delenv('CUDA_VISIBLE_DEVICES') - - -def test_nvidia_gpu_neg_val(monkeypatch: 
pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv('CUDA_VISIBLE_DEVICES', '-1') - resource = NvidiaGpuResource.from_system() - assert len(resource) == 0 - assert resource == [] - mcls.delenv('CUDA_VISIBLE_DEVICES') - - -def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43-ac33420d4628') - resource = NvidiaGpuResource.from_system() - assert len(resource) == 1 - assert resource == ['GPU-5ebe9f43-ac33420d4628'] - mcls.delenv('CUDA_VISIBLE_DEVICES') - with monkeypatch.context() as mcls: - mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43,GPU-ac33420d4628') - resource = NvidiaGpuResource.from_system() - assert len(resource) == 2 - assert resource == ['GPU-5ebe9f43', 'GPU-ac33420d4628'] - mcls.delenv('CUDA_VISIBLE_DEVICES') - with monkeypatch.context() as mcls: - mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43,-1,GPU-ac33420d4628') - resource = NvidiaGpuResource.from_system() - assert len(resource) == 1 - assert resource == ['GPU-5ebe9f43'] - mcls.delenv('CUDA_VISIBLE_DEVICES') - with monkeypatch.context() as mcls: - mcls.setenv('CUDA_VISIBLE_DEVICES', 'MIG-GPU-5ebe9f43-ac33420d4628') - resource = NvidiaGpuResource.from_system() - assert len(resource) == 1 - assert resource == ['MIG-GPU-5ebe9f43-ac33420d4628'] - mcls.delenv('CUDA_VISIBLE_DEVICES') - - -@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='skip GPUs test on CI') -def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - # to make this tests works with system that has GPU - mcls.setenv('CUDA_VISIBLE_DEVICES', '') - assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests - - assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match( - 'Input list should be all string type.' 
- ) - assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.') - assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match( - 'Failed to parse available GPUs UUID' - ) - - -def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as mcls: - # to make this tests works with system that has GPU - mcls.setenv('CUDA_VISIBLE_DEVICES', '') - assert NvidiaGpuResource.from_spec(1) == ['0'] - assert NvidiaGpuResource.from_spec('5') == ['0', '1', '2', '3', '4'] - assert NvidiaGpuResource.from_spec(1) == ['0'] - assert NvidiaGpuResource.from_spec(2) == ['0', '1'] - assert NvidiaGpuResource.from_spec('3') == ['0', '1', '2'] - assert NvidiaGpuResource.from_spec([1, 3]) == ['1', '3'] - assert NvidiaGpuResource.from_spec(['1', '3']) == ['1', '3'] - assert NvidiaGpuResource.from_spec(-1) == [] - assert NvidiaGpuResource.from_spec('-1') == [] - assert NvidiaGpuResource.from_spec('') == [] - assert NvidiaGpuResource.from_spec('-2') == [] - assert NvidiaGpuResource.from_spec('GPU-288347ab') == ['GPU-288347ab'] - assert NvidiaGpuResource.from_spec('GPU-288347ab,-1,GPU-ac33420d4628') == ['GPU-288347ab'] - assert NvidiaGpuResource.from_spec('GPU-288347ab,GPU-ac33420d4628') == ['GPU-288347ab', 'GPU-ac33420d4628'] - assert NvidiaGpuResource.from_spec('MIG-GPU-288347ab') == ['MIG-GPU-288347ab'] - - with pytest.raises(TypeError): - NvidiaGpuResource.from_spec((1, 2, 3)) - with pytest.raises(TypeError): - NvidiaGpuResource.from_spec(1.5) - with pytest.raises(ValueError): - assert NvidiaGpuResource.from_spec(-2) - - -class GPURunnable(bentoml.Runnable): - SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu') - - -def unvalidated_get_resource(x: dict[str, t.Any], y: str, validate: bool = False): - return get_resource(x, y, validate=validate) - - -@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu']) -def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str): - monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource) - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 1) == 1 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 1 - - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 0.5) == 1 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 1 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 1 - assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 1 - - -@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu']) -def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str): - monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource) - - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0) - assert envs.get('CUDA_VISIBLE_DEVICES') == '0' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1) - assert envs.get('CUDA_VISIBLE_DEVICES') == '1' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 1, 1) - assert envs.get('CUDA_VISIBLE_DEVICES') == '7' - - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 0) - assert envs.get('CUDA_VISIBLE_DEVICES') == '0' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 1) 
- assert envs.get('CUDA_VISIBLE_DEVICES') == '0' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 2) - assert envs.get('CUDA_VISIBLE_DEVICES') == '1' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 1) - assert envs.get('CUDA_VISIBLE_DEVICES') == '2' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 2) - assert envs.get('CUDA_VISIBLE_DEVICES') == '7' - - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 0.5, 0) - assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7' - - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 0) - assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 1) - assert envs.get('CUDA_VISIBLE_DEVICES') == '8,9' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.25, 0) - assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7,8,9' - - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 0) - assert envs.get('CUDA_VISIBLE_DEVICES') == '2,6' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 1) - assert envs.get('CUDA_VISIBLE_DEVICES') == '7,8' - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2) - assert envs.get('CUDA_VISIBLE_DEVICES') == '9' - - -@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu']) -def test_cascade_strategy_disabled_via_env(monkeypatch: MonkeyPatch, gpu_type: str): - monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource) - - monkeypatch.setenv('CUDA_VISIBLE_DEVICES', '') - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0) - assert envs.get('CUDA_VISIBLE_DEVICES') == '' - monkeypatch.delenv('CUDA_VISIBLE_DEVICES') - - monkeypatch.setenv('CUDA_VISIBLE_DEVICES', '-1') - envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1) - assert envs.get('CUDA_VISIBLE_DEVICES') == '-1' - monkeypatch.delenv('CUDA_VISIBLE_DEVICES') diff --git a/pyproject.toml b/pyproject.toml index df010b02..dbd4bde2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,7 @@ toplevel = ["openllm"] [tool.pytest.ini_options] -addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"] +addopts = ["-rfEX", "-pno:warnings"] python_files = ["test_*.py", "*_test.py"] testpaths = ["openllm-python/tests"]
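
The new `openllm-python/tests/_data.py` helper samples length-filtered ShareGPT conversations for the regression-style tests. Below is a minimal sketch (not part of the patch) of how it might be driven locally; it assumes `transformers` and `datasets` are installed, that downloading the ShareGPT split is acceptable in your environment, and that the module is importable as `tests._data` from within `openllm-python`. The tokenizer choice (`facebook/opt-125m`) is only an illustrative example to keep the download small.

```python
# Sketch: exercise prepare_sharegpt_request from openllm-python/tests/_data.py.
# Assumptions: transformers/datasets installed, module importable as tests._data,
# facebook/opt-125m used only as a small example tokenizer.
from transformers import AutoTokenizer

from tests._data import prepare_sharegpt_request

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')

# Sample 8 prompts; each entry is a dict with 'prompt', 'prompt_length', 'output_length'.
requests = prepare_sharegpt_request(8, tokenizer)

for req in requests:
    print(req['prompt_length'], req['output_length'], req['prompt'][:60])
```

Each returned entry carries the raw prompt plus token counts, and conversations that are too short (fewer than 4 tokens on either side) or too long (prompt over 1024 tokens, or prompt plus completion over 2048) have already been filtered out by the helper. The accompanying `regression_test.py` tests start a server via `python -m openllm start <model_id>` and probe it through both the OpenAI-compatible endpoint and the native async client, so running them locally presumably requires hardware capable of serving the model IDs parametrized in `conftest.py`.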