Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-04-23 08:28:24 -04:00
tests: add additional basic testing (#982)
* chore: update rebase tests
* chore: update partial clients before removing
* fix: update client parsing logic to work with 0.5
* chore: ignore CI runs so as to run the test suite locally
* chore: update async client tests
* chore: update pre-commit

Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
.github/workflows/ci.yml (vendored, 290 lines changed)
@@ -1,146 +1,144 @@
name: Continuous Integration
on:
  workflow_call:
  push:
    branches: [main]
    paths-ignore:
      - 'docs/**'
      - 'bazel/**'
      - 'typings/**'
      - '*.md'
      - 'changelog.d/**'
      - 'assets/**'
  pull_request:
    branches: [main]
    paths-ignore:
      - 'docs/**'
      - 'bazel/**'
      - 'typings/**'
      - '*.md'
      - 'changelog.d/**'
      - 'assets/**'
env:
  LINES: 120
  COLUMNS: 120
  OPENLLM_DO_NOT_TRACK: True
  PYTHONUNBUFFERED: '1'
  HATCH_VERBOSE: 2
# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
defaults:
  run:
    shell: bash --noprofile --norc -exo pipefail {0}
jobs:
  tests:
    runs-on: ${{ matrix.os }}
    if: ${{ github.event_name == 'pull_request' || github.event_name == 'push' || github.event_name == 'workflow_call' }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version: ['3.8', '3.11']
        exclude:
          - os: 'windows-latest'
    name: tests (${{ matrix.python-version }}.${{ matrix.os }})
    steps:
      - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6
        with:
          fetch-depth: 0
          ref: ${{ github.event.pull_request.head.sha }}
      - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1
        with:
          bentoml-version: 'main'
          python-version: ${{ matrix.python-version }}
      - name: Run tests
        run: hatch run tests:python
      - name: Disambiguate coverage filename
        run: mv .coverage ".coverage.${{ matrix.os }}.${{ matrix.python-version }}"
      - name: Upload coverage data
        uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3
        with:
          name: coverage-data
          path: .coverage.*
  coverage:
    name: report-coverage
    runs-on: ubuntu-latest
    if: false
    needs: tests
    steps:
      - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6
        with:
          fetch-depth: 0
          ref: ${{ github.event.pull_request.head.sha }}
      - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1
        with:
          bentoml-version: 'main'
          python-version-file: .python-version-default
      - name: Download coverage data
        uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a # ratchet:actions/download-artifact@v3
        with:
          name: coverage-data
      - name: Combine coverage data
        run: hatch run coverage:combine
      - name: Export coverage reports
        run: |
          hatch run coverage:report-xml openllm-python
          hatch run coverage:report-uncovered-html openllm-python
      - name: Upload uncovered HTML report
        uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3
        with:
          name: uncovered-html-report
          path: htmlcov
      - name: Generate coverage summary
        run: hatch run coverage:generate-summary
      - name: Write coverage summary report
        if: github.event_name == 'pull_request'
        run: hatch run coverage:write-summary-report
      - name: Update coverage pull request comment
        if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork
        uses: marocchino/sticky-pull-request-comment@331f8f5b4215f0445d3c07b4967662a32a2d3e31 # ratchet:marocchino/sticky-pull-request-comment@v2
        with:
          path: coverage-report.md
  cli-benchmark:
    name: Check for CLI responsiveness
    runs-on: ubuntu-latest
    env:
      HYPERFINE_VERSION: '1.12.0'
    steps:
      - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6
        with:
          fetch-depth: 0
      - name: Install hyperfine
        run: |
          wget https://github.com/sharkdp/hyperfine/releases/download/v${HYPERFINE_VERSION}/hyperfine_${HYPERFINE_VERSION}_amd64.deb
          sudo dpkg -i hyperfine_${HYPERFINE_VERSION}_amd64.deb
      - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1
        with:
          bentoml-version: 'main'
          python-version-file: .python-version-default
      - name: Install self
        run: bash local.sh
      - name: Speed
        run: hyperfine -m 100 --warmup 10 openllm
  brew-dry-run:
    name: Running dry-run tests for brew
    runs-on: macos-latest
    steps:
      - name: Install tap and dry-run
        run: |
          brew tap bentoml/openllm https://github.com/bentoml/openllm
          brew install openllm
          openllm --help
          openllm models --show-available
  evergreen: # https://github.com/marketplace/actions/alls-green#why
    if: always()
    needs:
      - tests
      - cli-benchmark
      - brew-dry-run
    runs-on: ubuntu-latest
    steps:
      - name: Decide whether the needed jobs succeeded or failed
        uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # ratchet:re-actors/alls-green@release/v1
        with:
          jobs: ${{ toJSON(needs) }}
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
  cancel-in-progress: true
# name: Continuous Integration
# on:
#   workflow_call:
#   push:
#     branches: [main]
#     paths-ignore:
#       - 'docs/**'
#       - 'bazel/**'
#       - 'typings/**'
#       - '*.md'
#       - 'changelog.d/**'
#       - 'assets/**'
#   pull_request:
#     branches: [main]
#     paths-ignore:
#       - 'docs/**'
#       - 'bazel/**'
#       - 'typings/**'
#       - '*.md'
#       - 'changelog.d/**'
#       - 'assets/**'
# env:
#   LINES: 120
#   COLUMNS: 120
#   OPENLLM_DO_NOT_TRACK: True
#   PYTHONUNBUFFERED: '1'
#   HATCH_VERBOSE: 2
# # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
# defaults:
#   run:
#     shell: bash --noprofile --norc -exo pipefail {0}
# jobs:
#   tests:
#     runs-on: ${{ matrix.os }}
#     if: ${{ github.event_name == 'pull_request' || github.event_name == 'push'|| github.event_name == 'workflow_call' }}
#     strategy:
#       fail-fast: false
#       matrix:
#         os: [ubuntu-latest]
#         python-version: ['3.9', '3.12']
#     name: tests (${{ matrix.python-version }}.${{ matrix.os }})
#     steps:
#       - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # ratchet:actions/checkout@v4.1.6
#         with:
#           fetch-depth: 0
#           ref: ${{ github.event.pull_request.head.sha }}
#       - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1
#         with:
#           bentoml-version: 'main'
#           python-version: ${{ matrix.python-version }}
#       - name: Run tests
#         run: hatch run tests:python
#       - name: Disambiguate coverage filename
#         run: mv .coverage ".coverage.${{ matrix.os }}.${{ matrix.python-version }}"
#       - name: Upload coverage data
#         uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3
#         with:
#           name: coverage-data
#           path: .coverage.*
#   coverage:
#     name: report-coverage
#     runs-on: ubuntu-latest
#     if: false
#     needs: tests
#     steps:
#       - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4.1.1
#         with:
#           fetch-depth: 0
#           ref: ${{ github.event.pull_request.head.sha }}
#       - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1
#         with:
#           bentoml-version: 'main'
#           python-version-file: .python-version-default
#       - name: Download coverage data
#         uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a # ratchet:actions/download-artifact@v3
#         with:
#           name: coverage-data
#       - name: Combine coverage data
#         run: hatch run coverage:combine
#       - name: Export coverage reports
#         run: |
#           hatch run coverage:report-xml openllm-python
#           hatch run coverage:report-uncovered-html openllm-python
#       - name: Upload uncovered HTML report
#         uses: actions/upload-artifact@a8a3f3ad30e3422c9c7b888a15615d19a852ae32 # ratchet:actions/upload-artifact@v3
#         with:
#           name: uncovered-html-report
#           path: htmlcov
#       - name: Generate coverage summary
#         run: hatch run coverage:generate-summary
#       - name: Write coverage summary report
#         if: github.event_name == 'pull_request'
#         run: hatch run coverage:write-summary-report
#       - name: Update coverage pull request comment
#         if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork
#         uses: marocchino/sticky-pull-request-comment@331f8f5b4215f0445d3c07b4967662a32a2d3e31 # ratchet:marocchino/sticky-pull-request-comment@v2
#         with:
#           path: coverage-report.md
#   cli-benchmark:
#     name: Check for CLI responsiveness
#     runs-on: ubuntu-latest
#     env:
#       HYPERFINE_VERSION: '1.12.0'
#     steps:
#       - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4.1.1
#         with:
#           fetch-depth: 0
#       - name: Install hyperfine
#         run: |
#           wget https://github.com/sharkdp/hyperfine/releases/download/v${HYPERFINE_VERSION}/hyperfine_${HYPERFINE_VERSION}_amd64.deb
#           sudo dpkg -i hyperfine_${HYPERFINE_VERSION}_amd64.deb
#       - uses: bentoml/setup-bentoml-action@862aa8fa0e0c3793fcca4bfe7a62717a497417e4 # ratchet:bentoml/setup-bentoml-action@v1
#         with:
#           bentoml-version: 'main'
#           python-version-file: .python-version-default
#       - name: Install self
#         run: bash local.sh
#       - name: Speed
#         run: hyperfine -m 100 --warmup 10 openllm
#   brew-dry-run:
#     name: Running dry-run tests for brew
#     runs-on: macos-latest
#     steps:
#       - name: Install tap and dry-run
#         run: |
#           brew tap bentoml/openllm https://github.com/bentoml/openllm
#           brew install openllm
#           openllm --help
#           openllm models --show-available
#   evergreen: # https://github.com/marketplace/actions/alls-green#why
#     if: always()
#     needs:
#       - tests
#       # - cli-benchmark
#       # - brew-dry-run
#     runs-on: ubuntu-latest
#     steps:
#       - name: Decide whether the needed jobs succeeded or failed
#         uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # ratchet:re-actors/alls-green@release/v1
#         with:
#           jobs: ${{ toJSON(needs) }}
# concurrency:
#   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
#   cancel-in-progress: true
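Note: the replacement ci.yml is the same workflow kept entirely commented out, which matches the commit's stated intent to skip CI runs and exercise the test suite locally; the commented copy also narrows the matrix to ubuntu-latest with Python 3.9 and 3.12 and comments cli-benchmark and brew-dry-run out of the evergreen needs list for whenever it is re-enabled.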
@@ -5,7 +5,7 @@ ci:
  autoupdate_commit_msg: 'ci: pre-commit autoupdate [pre-commit.ci]'
  autofix_prs: false
default_language_version:
  python: python3.9 # NOTE: sync with .python-version-default
  python: python3.11 # NOTE: sync with .python-version-default
exclude: '.*\.(css|js|svg)$'
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
@@ -1 +1 @@
3.9
3.11
@@ -1,11 +1,13 @@
from __future__ import annotations
import uuid
import uuid, os
from typing import Any, AsyncGenerator, Dict, TypedDict, Union

from bentoml import Service
from bentoml.io import JSON, Text
from openllm import LLM

os.environ['IMPLEMENTATION'] = 'deprecated'

llm = LLM[Any, Any]('HuggingFaceH4/zephyr-7b-alpha', backend='vllm')
examples/async_client.py (new file, 4 lines)
@@ -0,0 +1,4 @@
import openllm, asyncio
client = openllm.AsyncHTTPClient('http://0.0.0.0:3000')
async def main(): assert (await client.health()); print(await client.generate('Explain superconductor to a 5 year old.'))
if __name__ == "__main__": asyncio.run(main())
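For readability, the same four-line example can be written out long-form; this is a sketch that assumes an OpenLLM server is already listening at the address shown (endpoint and prompt are illustrative):

import asyncio

import openllm


async def main() -> None:
  # Assumes a server started with `openllm start` is reachable at this address.
  client = openllm.AsyncHTTPClient('http://0.0.0.0:3000')
  assert await client.health()  # /readyz must return 200 before we generate
  response = await client.generate('Explain superconductor to a 5 year old.')
  print(response)


if __name__ == '__main__':
  asyncio.run(main())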
@@ -14,7 +14,7 @@ model = models.data[0].id
# Chat completion API
stream = str(os.getenv('STREAM', False)).upper() in ['TRUE', '1', 'YES', 'Y', 'ON']
completions = client.chat.completions.create(messages=[
  ChatCompletionSystemMessageParam(role='system', content='You will be the writing assistant that assume the ton of Ernest Hemmingway.'),
  ChatCompletionSystemMessageParam(role='system', content='You will be the writing assistant that assume the tone of Ernest Hemmingway.'),
  ChatCompletionUserMessageParam(role='user', content='Write an essay on Nietzsche and absurdism.'),
], model=model, max_tokens=1024, stream=stream)
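Because `stream` toggles between a single response object and an iterator of chunks, the code that consumes `completions` has to branch. A minimal continuation of the example above, assuming the standard `openai` v1 client it already uses:

if stream:
  # With stream=True the OpenAI-compatible endpoint yields incremental chunks.
  for chunk in completions:
    delta = chunk.choices[0].delta.content
    if delta:
      print(delta, end='', flush=True)
else:
  print(completions.choices[0].message.content)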
hatch.toml (13 lines changed)
@@ -1,4 +1,6 @@
[envs.default]
installer = "uv"
type = "virtual"
dependencies = [
  "openllm-core @ {root:uri}/openllm-core",
  "openllm-client @ {root:uri}/openllm-client",
@@ -29,12 +31,12 @@ setup = [
quality = ["bash ./all.sh", "- pre-commit run --all-files", "- pnpm format"]
tool = ["quality", "bash ./clean.sh", 'python ./cz.py']
[envs.tests]
installer = "uv"
type = "virtual"
dependencies = [
  "openllm[vllm] @ {root:uri}/openllm-python",
  "openllm-core @ {root:uri}/openllm-core",
  "openllm-client @ {root:uri}/openllm-client",
  "openllm[chatglm,fine-tune] @ {root:uri}/openllm-python",
  # NOTE: interact with docker for container tests.
  "docker",
  # NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy
  "coverage[toml]>=6.5",
  "filelock>=3.7.1",
@@ -53,10 +55,7 @@ skip-install = false
template = "tests"
[envs.tests.scripts]
_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml -vv"
distributed = "_run_script --reruns 5 --reruns-delay 3 --ignore openllm-python/tests/models -n 3 -r aR {args:openllm-python/tests}"
models = "_run_script -s {args:openllm-python/tests/models}"
python = "_run_script --reruns 5 --reruns-delay 3 --ignore openllm-python/tests/models -r aR {args:openllm-python/tests}"
snapshot-models = "_run_script -s --snapshot-update {args:openllm-python/tests/models}"
python = "_run_script -r aR -x {args:openllm-python/tests}"
[envs.tests.overrides]
env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT="
[envs.coverage]
@@ -52,7 +52,7 @@ keywords = [
  "PyTorch",
  "Transformers",
]
dependencies = ["openllm-core", "attrs>=23.2.0", "httpx", "distro", "anyio"]
dependencies = ["openllm-core", "httpx", "distro", "anyio"]
license = "Apache-2.0"
name = "openllm-client"
requires-python = ">=3.8"
@@ -1,10 +1,6 @@
from __future__ import annotations
import importlib.metadata
import logging
import os
import typing as t

import attr
import importlib.metadata, logging, os, typing as t, orjson, pydantic

from ._schemas import Helpers, Metadata, Response, StreamingResponse
from ._shim import MAX_RETRIES, AsyncClient, Client
@@ -18,28 +14,31 @@ def _address_converter(addr: str):
  return addr if '://' in addr else 'http://' + addr


@attr.define(init=False)
class HTTPClient(Client):
  helpers: Helpers = attr.field(init=False)
  _api_version: str = 'v1'
  _verify: bool = True
  __metadata: Metadata | None = None
  __config: dict[str, t.Any] | None = None
  _helpers: Helpers = pydantic.PrivateAttr()
  _api_version: str = pydantic.PrivateAttr(default='v1')
  _verify: bool = pydantic.PrivateAttr(default=True)
  __metadata: t.Optional[Metadata] = None
  __config: t.Optional[t.Dict[str, t.Any]] = None

  def __repr__(self):
    return f'<HTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'
    return f'<HTTPClient address={self.address} timeout={self.timeout} api_version={self._api_version} verify={self._verify}>'

  def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
    if address is None:
      address = os.getenv('OPENLLM_ENDPOINT')
      if address is None:
        raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
    self._api_version, self._verify = api_version, verify

    self.helpers = Helpers(client=self)

    super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)

    self._helpers = Helpers(client=self)
    self._api_version, self._verify = api_version, verify

  @property
  def helpers(self):
    return self._helpers

  def _build_auth_headers(self) -> t.Dict[str, str]:
    env = os.getenv('OPENLLM_AUTH_TOKEN')
    if env is not None:
@@ -57,13 +56,13 @@ class HTTPClient(Client):
  def _metadata(self):
    if self.__metadata is None:
      path = f'/{self._api_version}/metadata'
      self.__metadata = self._post(path, response_cls=Metadata, json={}, options={'max_retries': self._max_retries})
      self.__metadata = self._post(path, response_cls=Metadata, json={}, options={'max_retries': self.max_retries})
    return self.__metadata

  @property
  def _config(self) -> dict[str, t.Any]:
    if self.__config is None:
      self.__config = self._metadata.configuration
      self.__config = orjson.loads(self._metadata.configuration)
    return self.__config

  def query(self, prompt, **attrs):
@@ -71,7 +70,7 @@ class HTTPClient(Client):

  def health(self):
    response = self._get(
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self.max_retries}
    )
    return response.status_code == 200

@@ -79,7 +78,7 @@ class HTTPClient(Client):
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> Response:
    if timeout is None:
      timeout = self._timeout
      timeout = self.timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    if llm_config is not None:
@@ -91,7 +90,7 @@ class HTTPClient(Client):
      f'/{self._api_version}/generate',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
      options={'max_retries': self.max_retries},
    )

  def generate_stream(
@@ -104,7 +103,7 @@ class HTTPClient(Client):
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.Iterator[Response]:
    if timeout is None:
      timeout = self._timeout
      timeout = self.timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    if llm_config is not None:
@@ -115,34 +114,37 @@ class HTTPClient(Client):
      f'/{self._api_version}/generate_stream',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
      options={'max_retries': self.max_retries},
      stream=True,
    )
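The migration pattern above (attrs fields replaced by pydantic private attributes plus an explicit read-only property) reduces to a small self-contained sketch; the class and field names here are illustrative, not the library's API:

import pydantic


class ExampleClient(pydantic.BaseModel):
  # Private attributes are excluded from validation and serialisation, which is
  # why the migration exposes them through properties instead of plain fields.
  _api_version: str = pydantic.PrivateAttr(default='v1')
  _verify: bool = pydantic.PrivateAttr(default=True)

  address: str  # a regular validated field

  @property
  def api_version(self) -> str:
    return self._api_version


client = ExampleClient(address='http://0.0.0.0:3000')
print(client.api_version)  # -> 'v1'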
@attr.define(init=False)
class AsyncHTTPClient(AsyncClient):
  helpers: Helpers = attr.field(init=False)
  _api_version: str = 'v1'
  _verify: bool = True
  __metadata: Metadata | None = None
  __config: dict[str, t.Any] | None = None
class AsyncHTTPClient(AsyncClient, pydantic.BaseModel):
  _helpers: Helpers = pydantic.PrivateAttr()
  _api_version: str = pydantic.PrivateAttr(default='v1')
  _verify: bool = pydantic.PrivateAttr(default=True)
  __metadata: t.Optional[Metadata] = None
  __config: t.Optional[t.Dict[str, t.Any]] = None

  def __repr__(self):
    return f'<AsyncHTTPClient address={self.address} timeout={self._timeout} api_version={self._api_version} verify={self._verify}>'
    return f'<AsyncHTTPClient address={self.address} timeout={self.timeout} api_version={self._api_version} verify={self._verify}>'

  def __init__(self, address=None, timeout=30, verify=False, max_retries=MAX_RETRIES, api_version='v1'):
    if address is None:
      address = os.getenv('OPENLLM_ENDPOINT')
      if address is None:
        raise ValueError("address must either be provided or through 'OPENLLM_ENDPOINT'")
    self._api_version, self._verify = api_version, verify

    # mk messages to be async here
    self.helpers = Helpers.permute(messages=Helpers.async_messages)(async_client=self)

    super().__init__(_address_converter(address), VERSION, timeout=timeout, max_retries=max_retries)

    # mk messages to be async here
    self._helpers = Helpers.permute(messages=Helpers.async_messages)(async_client=self)
    self._api_version, self._verify = api_version, verify

  @property
  def helpers(self):
    return self._helpers

  def _build_auth_headers(self) -> t.Dict[str, str]:
    env = os.getenv('OPENLLM_AUTH_TOKEN')
    if env is not None:
@@ -153,14 +155,14 @@ class AsyncHTTPClient(AsyncClient):
  async def _metadata(self) -> t.Awaitable[Metadata]:
    if self.__metadata is None:
      self.__metadata = await self._post(
        f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}
        f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self.max_retries}
      )
    return self.__metadata

  @property
  async def _config(self):
    if self.__config is None:
      self.__config = (await self._metadata).configuration
      self.__config = orjson.loads((await self._metadata).configuration)
    return self.__config

  async def query(self, prompt, **attrs):
@@ -168,7 +170,7 @@ class AsyncHTTPClient(AsyncClient):

  async def health(self):
    response = await self._get(
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}
      '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self.max_retries}
    )
    return response.status_code == 200

@@ -176,7 +178,7 @@ class AsyncHTTPClient(AsyncClient):
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> Response:
    if timeout is None:
      timeout = self._timeout
      timeout = self.timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    _metadata = await self._metadata
@@ -189,22 +191,22 @@ class AsyncHTTPClient(AsyncClient):
      f'/{self._api_version}/generate',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
      options={'max_retries': self.max_retries},
    )

  async def generate_stream(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, index=0, **attrs
  ) -> t.AsyncGenerator[StreamingResponse, t.Any]:
    async for response_chunk in self.generate_iterator(
      prompt, llm_config, stop, adapter_name, timeout, verify, **attrs
    ):
      yield StreamingResponse.from_response_chunk(response_chunk)
      yield StreamingResponse.from_response_chunk(response_chunk, index=index)

  async def generate_iterator(
    self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
  ) -> t.AsyncGenerator[Response, t.Any]:
    if timeout is None:
      timeout = self._timeout
      timeout = self.timeout
    if verify is None:
      verify = self._verify  # XXX: need to support this again
    _metadata = await self._metadata
@@ -218,7 +220,7 @@ class AsyncHTTPClient(AsyncClient):
      f'/{self._api_version}/generate_stream',
      response_cls=Response,
      json=dict(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name),
      options={'max_retries': self._max_retries},
      options={'max_retries': self.max_retries},
      stream=True,
    ):
      yield response_chunk
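A short consumer for the streaming path above, as a sketch; it assumes a running server and uses the new `index` keyword that `generate_stream` now forwards to `StreamingResponse.from_response_chunk`:

import asyncio

import openllm


async def stream_demo() -> None:
  client = openllm.AsyncHTTPClient('http://0.0.0.0:3000')
  # index=0 selects the first candidate output from each streamed chunk.
  async for chunk in client.generate_stream('Tell me a short story.', index=0):
    print(chunk.text, end='', flush=True)


asyncio.run(stream_demo())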
@@ -1,17 +1,13 @@
from __future__ import annotations
import types
import typing as t

import attr
import orjson
import types, pydantic, typing as t

from openllm_core._schemas import (
  CompletionChunk as CompletionChunk,
  GenerationOutput as Response,  # backward compatibility
  _SchemaMixin as _SchemaMixin,
  MetadataOutput as Metadata,
  GenerationOutput as Response,
)

from ._utils import converter
from ._typing_compat import TypedDict

if t.TYPE_CHECKING:
  from ._shim import AsyncClient, Client
@@ -20,65 +16,28 @@ if t.TYPE_CHECKING:
__all__ = ['CompletionChunk', 'Helpers', 'Metadata', 'Response', 'StreamingResponse']


@attr.define
class Metadata(_SchemaMixin):
  """NOTE: Metadata is a modified version of the original MetadataOutput from openllm-core.

  The configuration is now structured into a dictionary for easy of use."""

  model_id: str
  timeout: int
  model_name: str
  backend: str
  configuration: t.Dict[str, t.Any]


def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metadata:
  try:
    configuration = orjson.loads(data['configuration'])
    generation_config = configuration.pop('generation_config')
    configuration = {**configuration, **generation_config}
  except orjson.JSONDecodeError as e:
    raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None
  try:
    return cls(
      model_id=data['model_id'],
      timeout=data['timeout'],
      model_name=data['model_name'],
      backend=data['backend'],
      configuration=configuration,
    )
  except Exception as e:
    raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None


converter.register_structure_hook(Metadata, _structure_metadata)


@attr.define
class StreamingResponse(_SchemaMixin):
class StreamingResponse(pydantic.BaseModel):
  request_id: str
  index: int
  text: str
  token_ids: int

  @classmethod
  def from_response_chunk(cls, response: Response) -> StreamingResponse:
  def from_response_chunk(cls, response: Response, index: int = 0) -> StreamingResponse:
    return cls(
      request_id=response.request_id,
      index=response.outputs[0].index,
      text=response.outputs[0].text,
      token_ids=response.outputs[0].token_ids[0],
      index=response.outputs[index].index,
      text=response.outputs[index].text,
      token_ids=response.outputs[index].token_ids[0],
    )


class MesssageParam(t.TypedDict):
class MesssageParam(TypedDict):
  role: t.Literal['user', 'system', 'assistant']
  content: str


@attr.define(repr=False)
class Helpers:
class Helpers(pydantic.BaseModel):
  _client: t.Optional[Client] = None
  _async_client: t.Optional[AsyncClient] = None
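The effect of the new `index` parameter is easiest to see in isolation; a minimal sketch with stand-in chunk types (the `FakeOutput`/`FakeChunk` names are illustrative, not part of the library):

import typing as t

import pydantic


class FakeOutput(pydantic.BaseModel):
  index: int
  text: str
  token_ids: t.List[int]


class FakeChunk(pydantic.BaseModel):
  request_id: str
  outputs: t.List[FakeOutput]


chunk = FakeChunk(
  request_id='req-1',
  outputs=[FakeOutput(index=0, text='a', token_ids=[1]), FakeOutput(index=1, text='b', token_ids=[2])],
)
# `index` now selects which candidate output a StreamingResponse is built from,
# instead of always taking outputs[0].
print(chunk.outputs[1].text)  # -> 'b'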
@@ -1,22 +1,11 @@
# This provides a base shim with httpx and acts as base request
from __future__ import annotations

import asyncio
import email.utils
import logging
import platform
import random
import time
import typing as t
import asyncio, logging, platform, random, time, typing as t
import email.utils, anyio, distro, httpx, pydantic

import anyio
import attr
import distro
import httpx

from ._stream import AsyncStream, Response, Stream
from ._stream import AsyncStream, Stream, Response
from ._typing_compat import Annotated, Architecture, LiteralString, Platform
from ._utils import converter

logger = logging.getLogger(__name__)

@@ -92,15 +81,16 @@ def _architecture() -> Architecture:


@t.final
@attr.frozen(auto_attribs=True)
class RequestOptions:
  method: str = attr.field(converter=str.lower)
class RequestOptions(pydantic.BaseModel):
  model_config = pydantic.ConfigDict(extra='forbid', protected_namespaces=())

  method: pydantic.constr(to_lower=True)
  url: str
  json: t.Optional[t.Dict[str, t.Any]] = attr.field(default=None)
  params: t.Optional[t.Mapping[str, t.Any]] = attr.field(default=None)
  headers: t.Optional[t.Dict[str, str]] = attr.field(default=None)
  max_retries: int = attr.field(default=MAX_RETRIES)
  return_raw_response: bool = attr.field(default=False)
  json: t.Optional[t.Dict[str, t.Any]] = pydantic.Field(default=None)
  params: t.Optional[t.Mapping[str, t.Any]] = pydantic.Field(default=None)
  headers: t.Optional[t.Dict[str, str]] = pydantic.Field(default=None)
  max_retries: int = pydantic.Field(default=MAX_RETRIES)
  return_raw_response: bool = pydantic.Field(default=False)

  def get_max_retries(self, max_retries: int | None) -> int:
    return max_retries if max_retries is not None else self.max_retries
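A quick check of what the pydantic version of `RequestOptions` buys: the method is lower-cased during validation, and `extra='forbid'` rejects unknown fields. A sketch under those assumptions:

import pydantic


class Options(pydantic.BaseModel):
  model_config = pydantic.ConfigDict(extra='forbid')

  method: pydantic.constr(to_lower=True)
  url: str


print(Options(method='POST', url='/v1/generate').method)  # -> 'post'
try:
  Options(method='GET', url='/readyz', unknown=1)  # extra field
except pydantic.ValidationError as e:
  print('extra fields are rejected:', e.error_count(), 'error(s)')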
@@ -111,95 +101,97 @@ class RequestOptions:


@t.final
@attr.frozen(auto_attribs=True)
class APIResponse(t.Generic[Response]):
  _raw_response: httpx.Response
  _client: BaseClient[t.Any, t.Any]
  _response_cls: type[Response]
  _stream: bool
  _stream_cls: t.Optional[t.Union[t.Type[Stream[t.Any]], t.Type[AsyncStream[t.Any]]]]
  _options: RequestOptions
class APIResponse(pydantic.BaseModel, t.Generic[Response]):
  model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

  _parsed: t.Optional[Response] = attr.field(default=None, repr=False)
  raw_response: httpx.Response
  client: t.Union[AsyncClient, Client]
  response_cls: t.Optional[type[Response]]
  stream: bool
  stream_cls: t.Optional[t.Union[t.Type[Stream[t.Any]], t.Type[AsyncStream[t.Any]]]]
  options: RequestOptions
  _parsed: t.Optional[Response] = pydantic.PrivateAttr(default=None)

  def parse(self):
    if self._options.return_raw_response:
      return self._raw_response
    if self.options.return_raw_response:
      return self.raw_response

    if self.response_cls is None:
      raise ValueError('Response class cannot be None.')

    if self._parsed is not None:
      return self._parsed
    if self._stream:
      stream_cls = self._stream_cls or self._client._default_stream_cls
      return stream_cls(response_cls=self._response_cls, response=self._raw_response, client=self._client)

    if self._response_cls is str:
      return self._raw_response.text
    if self.stream:
      stream_cls = self.stream_cls or self.client.default_stream_cls
      return stream_cls(response_cls=self.response_cls, response=self.raw_response, client=self.client)

    content_type, *_ = self._raw_response.headers.get('content-type', '').split(';')
    if self.response_cls is str:
      return self.raw_response.text

    content_type, *_ = self.raw_response.headers.get('content-type', '').split(';')
    if content_type != 'application/json':
      # Since users specific different content_type, then we return the raw binary text without and deserialisation
      return self._raw_response.text
      return self.raw_response.text

    data = self._raw_response.json()
    data = self.raw_response.json()
    try:
      return self._client._process_response_data(
        data=data, response_cls=self._response_cls, raw_response=self._raw_response
      return self.client._process_response_data(
        data=data, response_cls=self.response_cls, raw_response=self.raw_response
      )
    except Exception as exc:
      raise ValueError(exc) from None  # validation error here

  @property
  def headers(self):
    return self._raw_response.headers
    return self.raw_response.headers

  @property
  def status_code(self):
    return self._raw_response.status_code
    return self.raw_response.status_code

  @property
  def request(self):
    return self._raw_response.request
    return self.raw_response.request

  @property
  def url(self):
    return self._raw_response.url
    return self.raw_response.url

  @property
  def content(self):
    return self._raw_response.content
    return self.raw_response.content

  @property
  def text(self):
    return self._raw_response.text
    return self.raw_response.text

  @property
  def http_version(self):
    return self._raw_response.http_version
    return self.raw_response.http_version


@attr.define(init=False)
class BaseClient(t.Generic[InnerClient, StreamType]):
  _base_url: httpx.URL = attr.field(converter=_address_converter)
  _version: LiteralVersion
  _timeout: httpx.Timeout = attr.field(converter=httpx.Timeout)
  _max_retries: int
  _inner: InnerClient
  _default_stream_cls: t.Type[StreamType]
  _auth_headers: t.Dict[str, str] = attr.field(init=False)
class BaseClient(pydantic.BaseModel, t.Generic[InnerClient, StreamType]):
  model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

  def __init__(
    self,
    *,
    base_url: str | httpx.URL,
    version: str,
    timeout: int | httpx.Timeout = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES,
    client: InnerClient,
    _default_stream_cls: t.Type[StreamType],
    _internal: bool = False,
  ):
    if not _internal:
      raise RuntimeError('Client is reserved to be used internally only.')
    self.__attrs_init__(base_url, version, timeout, max_retries, client, _default_stream_cls)
  _base_url: httpx.URL = pydantic.PrivateAttr()
  version: pydantic.SkipValidation[LiteralVersion]
  timeout: httpx.Timeout
  max_retries: int
  inner: InnerClient
  default_stream_cls: pydantic.SkipValidation[t.Type[StreamType]]
  _auth_headers: t.Dict[str, str] = pydantic.PrivateAttr()

  @pydantic.field_validator('timeout', mode='before')
  @classmethod
  def convert_timeout(cls, value: t.Any) -> httpx.Timeout:
    return httpx.Timeout(value)

  def __init__(self, base_url: str, **data: t.Any):
    super().__init__(**data)
    self._base_url = _address_converter(base_url)

  def model_post_init(self, *_: t.Any):
    self._auth_headers = self._build_auth_headers()

  def _prepare_url(self, url: str) -> httpx.URL:
@@ -212,7 +204,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]):

  @property
  def is_closed(self):
    return self._inner.is_closed
    return self.inner.is_closed

  @property
  def is_ready(self):
@@ -239,7 +231,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]):

  @property
  def user_agent(self):
    return f'{self.__class__.__name__}/Python {self._version}'
    return f'{self.__class__.__name__}/Python {self.version}'

  @property
  def auth_headers(self):
@@ -258,7 +250,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
  @property
  def platform_headers(self):
    return {
      'X-OpenLLM-Client-Package-Version': self._version,
      'X-OpenLLM-Client-Package-Version': self.version,
      'X-OpenLLM-Client-Language': 'Python',
      'X-OpenLLM-Client-Runtime': platform.python_implementation(),
      'X-OpenLLM-Client-Runtime-Version': platform.python_version(),
@@ -267,13 +259,13 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
    }

  def _remaining_retries(self, remaining_retries: int | None, options: RequestOptions) -> int:
    return remaining_retries if remaining_retries is not None else options.get_max_retries(self._max_retries)
    return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)

  def _build_headers(self, options: RequestOptions) -> httpx.Headers:
    return httpx.Headers(_merge_mapping(self._default_headers, options.headers or {}))

  def _build_request(self, options: RequestOptions) -> httpx.Request:
    return self._inner.build_request(
    return self.inner.build_request(
      method=options.method,
      headers=self._build_headers(options),
      url=self._prepare_url(options.url),
@@ -284,7 +276,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
  def _calculate_retry_timeout(
    self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None
  ) -> float:
    max_retries = options.get_max_retries(self._max_retries)
    max_retries = options.get_max_retries(self.max_retries)
    # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
    try:
      if headers is not None:
@@ -327,7 +319,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
  def _process_response_data(
    self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response
  ) -> Response:
    return converter.structure(data, response_cls)
    return response_cls(**data)

  def _process_response(
    self,
@@ -348,7 +340,6 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
    ).parse()
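The retry-delay logic referenced in `_calculate_retry_timeout` honours the HTTP Retry-After header (see the MDN link in the hunk); the shim's exact implementation is not shown here, so the following is a hedged sketch of the standard approach, using `email.utils` to cover the HTTP-date form of the header:

import email.utils
import random
import time


def retry_after_seconds(header_value: str) -> float:
  """Parse Retry-After, which is either an integer of seconds or an HTTP date."""
  try:
    return float(header_value)
  except ValueError:
    parsed = email.utils.parsedate_to_datetime(header_value)
    return max(0.0, parsed.timestamp() - time.time())


def backoff(retries_taken: int, initial: float = 0.5, cap: float = 8.0) -> float:
  # Exponential backoff with jitter, used when no Retry-After header is present.
  return min(cap, initial * 2 ** retries_taken) * (1 + random.random() * 0.25)


print(retry_after_seconds('3'))  # -> 3.0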
@attr.define(init=False)
class Client(BaseClient[httpx.Client, Stream[t.Any]]):
  def __init__(
    self,
@@ -362,18 +353,17 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
      version=version,
      timeout=timeout,
      max_retries=max_retries,
      client=httpx.Client(base_url=base_url, timeout=timeout),
      _default_stream_cls=Stream,
      _internal=True,
      inner=httpx.Client(base_url=base_url, timeout=timeout),
      default_stream_cls=Stream,
    )

  def close(self):
    self._inner.close()
    self.inner.close()

  def __enter__(self):
    return self

  def __exit__(self, *args) -> None:
  def __exit__(self, *args: t.Any) -> None:
    self.close()

  def __del__(self):
@@ -408,7 +398,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
    retries = self._remaining_retries(remaining_retries, options)
    request = self._build_request(options)
    try:
      response = self._inner.send(request, auth=self.auth, stream=stream)
      response = self.inner.send(request, auth=self.auth, stream=stream)
      logger.debug('HTTP [%s, %s]: %i [%s]', request.method, request.url, response.status_code, response.reason_phrase)
      response.raise_for_status()
    except httpx.HTTPStatusError as exc:
@@ -418,7 +408,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
      )
      # If the response is streamed then we need to explicitly read the completed response
      exc.response.read()
      raise ValueError(exc.message) from None
      raise ValueError(exc) from None
    except httpx.TimeoutException:
      if retries > 0:
        return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
@@ -481,7 +471,6 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]):
    )


@attr.define(init=False)
class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
  def __init__(
    self,
@@ -495,18 +484,17 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
      version=version,
      timeout=timeout,
      max_retries=max_retries,
      client=httpx.AsyncClient(base_url=base_url, timeout=timeout),
      _default_stream_cls=AsyncStream,
      _internal=True,
      inner=httpx.AsyncClient(base_url=base_url, timeout=timeout),
      default_stream_cls=AsyncStream,
    )

  async def close(self):
    await self._inner.aclose()
    await self.inner.aclose()

  async def __aenter__(self):
    return self

  async def __aexit__(self, *args) -> None:
  async def __aexit__(self, *args: t.Any) -> None:
    await self.close()

  def __del__(self):
@@ -545,7 +533,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
    request = self._build_request(options)

    try:
      response = await self._inner.send(request, auth=self.auth, stream=stream)
      response = await self.inner.send(request, auth=self.auth, stream=stream)
      logger.debug('HTTP [%s, %s]: %i [%s]', request.method, request.url, response.status_code, response.reason_phrase)
      response.raise_for_status()
    except httpx.HTTPStatusError as exc:
@@ -555,7 +543,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
      )
      # If the response is streamed then we need to explicitly read the completed response
      await exc.response.aread()
      raise ValueError(exc.message) from None
      raise ValueError(exc) from None
    except httpx.ConnectTimeout as err:
      if retries > 0:
        return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls)
@@ -622,3 +610,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]):
    return await self.request(
      response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls
    )


Stream.model_rebuild()
AsyncStream.model_rebuild()
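The trailing `model_rebuild()` calls are needed because `Stream`/`AsyncStream` refer to `Client`/`AsyncClient` only as forward references, and pydantic can finalise the generic models only once those names exist. A minimal sketch of the same pattern, with illustrative names:

from __future__ import annotations

import typing as t

import pydantic

T = t.TypeVar('T')


class Box(pydantic.BaseModel, t.Generic[T]):
  model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
  owner: 'Owner'  # forward reference, resolved only after Owner is defined
  value: T


class Owner:
  pass


Box.model_rebuild()  # resolves 'Owner' now that the class exists
print(Box[int](owner=Owner(), value=3).value)  # -> 3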
@@ -1,26 +1,23 @@
from __future__ import annotations
import typing as t

import attr
import httpx
import orjson
import pydantic, httpx, orjson, typing as t

if t.TYPE_CHECKING:
  from ._shim import AsyncClient, Client

Response = t.TypeVar('Response', bound=attr.AttrsInstance)
Response = t.TypeVar('Response', bound=pydantic.BaseModel)


@attr.define(auto_attribs=True)
class Stream(t.Generic[Response]):
  _response_cls: t.Type[Response]
  _response: httpx.Response
  _client: Client
  _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder())
  _iterator: t.Iterator[Response] = attr.field(init=False)
class Stream(pydantic.BaseModel, t.Generic[Response]):
  model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
  response_cls: t.Type[Response]
  response: pydantic.SkipValidation[httpx.Response]
  client: Client
  _decoder: SSEDecoder = pydantic.PrivateAttr(default_factory=lambda: SSEDecoder())
  _iterator: t.Iterator[Response] = pydantic.PrivateAttr()

  def __init__(self, response_cls, response, client):
    self.__attrs_init__(response_cls, response, client)
  def __init__(self, **data):
    super().__init__(**data)
    self._iterator = self._stream()

  def __next__(self):
@@ -31,28 +28,28 @@ class Stream(t.Generic[Response]):
    yield item

  def _iter_events(self) -> t.Iterator[SSE]:
    yield from self._decoder.iter(self._response.iter_lines())
    yield from self._decoder.iter(self.response.iter_lines())

  def _stream(self) -> t.Iterator[Response]:
    for sse in self._iter_events():
      if sse.data.startswith('[DONE]'):
        break
      if sse.event is None:
        yield self._client._process_response_data(
          data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
        yield self.client._process_response_data(
          data=orjson.loads(sse.data), response_cls=self.response_cls, raw_response=self.response
        )


@attr.define(auto_attribs=True)
class AsyncStream(t.Generic[Response]):
  _response_cls: t.Type[Response]
  _response: httpx.Response
  _client: AsyncClient
  _decoder: SSEDecoder = attr.field(factory=lambda: SSEDecoder())
  _iterator: t.Iterator[Response] = attr.field(init=False)
class AsyncStream(pydantic.BaseModel, t.Generic[Response]):
  model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
  response_cls: t.Type[Response]
  response: pydantic.SkipValidation[httpx.Response]
  client: AsyncClient
  _decoder: SSEDecoder = pydantic.PrivateAttr(default_factory=lambda: SSEDecoder())
  _iterator: t.Iterator[Response] = pydantic.PrivateAttr()

  def __init__(self, response_cls, response, client):
    self.__attrs_init__(response_cls, response, client)
  def __init__(self, **data):
    super().__init__(**data)
    self._iterator = self._stream()

  async def __anext__(self):
@@ -63,7 +60,7 @@ class AsyncStream(t.Generic[Response]):
    yield item

  async def _iter_events(self):
    async for sse in self._decoder.aiter(self._response.aiter_lines()):
    async for sse in self._decoder.aiter(self.response.aiter_lines()):
      yield sse

  async def _stream(self) -> t.AsyncGenerator[Response, None]:
@@ -71,31 +68,23 @@ class AsyncStream(t.Generic[Response]):
      if sse.data.startswith('[DONE]'):
        break
      if sse.event is None:
        yield self._client._process_response_data(
          data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response
        yield self.client._process_response_data(
          data=orjson.loads(sse.data), response_cls=self.response_cls, raw_response=self.response
        )


@attr.define
class SSE:
  data: str = attr.field(default='')
  id: t.Optional[str] = attr.field(default=None)
  event: t.Optional[str] = attr.field(default=None)
  retry: t.Optional[int] = attr.field(default=None)

  def model_dump(self) -> t.Dict[str, t.Any]:
    try:
      return orjson.loads(self.data)
    except orjson.JSONDecodeError:
      raise
class SSE(pydantic.BaseModel):
  data: str = pydantic.Field(default='')
  id: t.Optional[str] = pydantic.Field(default=None)
  event: t.Optional[str] = pydantic.Field(default=None)
  retry: t.Optional[int] = pydantic.Field(default=None)


@attr.define(auto_attribs=True)
class SSEDecoder:
  _data: t.List[str] = attr.field(factory=list)
  _event: t.Optional[str] = None
  _retry: t.Optional[int] = None
  _last_event_id: t.Optional[str] = None
class SSEDecoder(pydantic.BaseModel):
  _data: t.List[str] = pydantic.PrivateAttr(default_factory=list)
  event: t.Optional[str] = None
  retry: t.Optional[int] = None
  last_event_id: t.Optional[str] = None

  def iter(self, iterator: t.Iterator[str]) -> t.Iterator[SSE]:
    for line in iterator:
@@ -112,10 +101,10 @@ class SSEDecoder:
  def decode(self, line: str) -> SSE | None:
    # NOTE: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
    if not line:
      if all(not a for a in [self._event, self._data, self._retry, self._last_event_id]):
      if all(not a for a in [self.event, self._data, self.retry, self.last_event_id]):
        return None
      sse = SSE(data='\n'.join(self._data), event=self._event, retry=self._retry, id=self._last_event_id)
      self._event, self._data, self._retry = None, [], None
      sse = SSE(data='\n'.join(self._data), event=self.event, retry=self.retry, id=self.last_event_id)
      self.event, self._data, self.retry = None, [], None
      return sse
    if line.startswith(':'):
      return None
@@ -123,17 +112,17 @@ class SSEDecoder:
    if value.startswith(' '):
      value = value[1:]
    if field == 'event':
      self._event = value
      self.event = value
    elif field == 'data':
      self._data.append(value)
    elif field == 'id':
      if '\0' in value:
        pass
      else:
        self._last_event_id = value
        self.last_event_id = value
    elif field == 'retry':
      try:
        self._retry = int(value)
        self.retry = int(value)
      except (TypeError, ValueError):
        pass
    else:
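To make the decoder's contract concrete, here is a hedged sketch of feeding it raw SSE lines; it assumes the `SSEDecoder`/`SSE` shapes shown above, where a blank line flushes the accumulated event:

lines = [
  'event: completion',
  'data: {"text": "hello"}',
  '',  # blank line terminates the event, so decode() emits an SSE object
]
decoder = SSEDecoder()
for sse in decoder.iter(iter(lines)):
  print(sse.event, sse.data)  # -> completion {"text": "hello"}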
@@ -7,6 +7,7 @@ from openllm_core._typing_compat import (
  Required as Required,
  Self as Self,
  TypeGuard as TypeGuard,
  TypedDict as TypedDict,
  dataclass_transform as dataclass_transform,
  overload as overload,
)
openllm-python/README.md (generated, 2 lines changed)
@@ -27,7 +27,7 @@ OpenLLM helps developers **run any open-source LLMs**, such as Llama 2 and Mistr
- 🚂 Support a wide range of open-source LLMs including LLMs fine-tuned with your own data
- ⛓️ OpenAI compatible API endpoints for seamless transition from your LLM app to open-source LLMs
- 🔥 State-of-the-art serving and inference performance
- 🎯 Simplified cloud deployment via [BentoML](www.bentoml.com)
- 🎯 Simplified cloud deployment via [BentoML](https://www.bentoml.com)

<!-- hatch-fancy-pypi-readme intro stop -->
@@ -150,45 +150,6 @@ only-include = ["src/openllm", "src/openllm_cli", "src/_openllm_tiny"]
sources = ["src"]
[tool.hatch.build.targets.sdist]
exclude = ["/.git_archival.txt", "tests", "/.python-version-default"]
[tool.hatch.build.targets.wheel.hooks.mypyc]
dependencies = [
  "hatch-mypyc==0.16.0",
  "mypy==1.7.0",
  # avoid https://github.com/pallets/click/issues/2558
  "click==8.1.3",
  "bentoml==1.1.9",
  "transformers>=4.32.1",
  "pandas-stubs",
  "types-psutil",
  "types-tabulate",
  "types-PyYAML",
  "types-protobuf",
]
enable-by-default = false
exclude = ["src/_openllm_tiny/_service.py", "src/openllm/utils/__init__.py"]
include = [
  "src/openllm/__init__.py",
  "src/openllm/_quantisation.py",
  "src/openllm/_generation.py",
  "src/openllm/exceptions.py",
  "src/openllm/testing.py",
  "src/openllm/utils",
]
# NOTE: This is consistent with pyproject.toml
mypy-args = [
  "--strict",
  # this is because all transient library doesn't have types
  "--follow-imports=skip",
  "--allow-subclassing-any",
  "--check-untyped-defs",
  "--ignore-missing-imports",
  "--no-warn-return-any",
  "--warn-unreachable",
  "--no-warn-no-return",
  "--no-warn-unused-ignores",
]
options = { verbose = true, strip_asserts = true, debug_level = "2", opt_level = "3", include_runtime_files = true }
require-runtime-dependencies = true
[tool.hatch.metadata.hooks.fancy-pypi-readme]
content-type = "text/markdown"
# PyPI doesn't support the <picture> tag.
@@ -1,4 +1,4 @@
if __name__ == '__main__':
  from openllm_cli.entrypoint import cli
  from _openllm_tiny._entrypoint import cli

  cli()
@@ -70,7 +70,7 @@ def error_response(status_code, message):
  )


async def check_model(request, model):
async def check_model(request, model):  # noqa
  if request.model == model:
    return None
  return error_response(
@@ -1,9 +0,0 @@
from __future__ import annotations
import os

from hypothesis import HealthCheck, settings

settings.register_profile('CI', settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None)

if 'CI' in os.environ:
  settings.load_profile('CI')
openllm-python/tests/_data.py (new file, 78 lines)
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations

import typing as t
from openllm_core._typing_compat import TypedDict
from datasets import load_dataset

if t.TYPE_CHECKING:
  from transformers import PreTrainedTokenizerBase

FIXED_OUTPUT_LENGTH = 128


class DatasetEntry(TypedDict):
  human: str
  gpt: str


class SampledRequest(TypedDict):
  prompt: str
  prompt_length: int
  output_length: int


def prepare_sharegpt_request(
  num_requests: int, tokenizer: PreTrainedTokenizerBase, max_output_length: int | None = None
) -> list[SampledRequest]:
  def transform(examples) -> DatasetEntry:
    human, gpt = [], []
    for example in examples['conversations']:
      human.append(example[0]['value'])
      gpt.append(example[1]['value'])
    return {'human': human, 'gpt': gpt}

  def process(examples, tokenizer, max_output_length: t.Optional[int]):
    # Tokenize the 'human' and 'gpt' values in batches
    prompt_token_ids = tokenizer(examples['human']).input_ids
    completion_token_ids = tokenizer(examples['gpt']).input_ids

    # Create the transformed entries
    return {
      'prompt': examples['human'],
      'prompt_length': [len(ids) for ids in prompt_token_ids],
      'output_length': [
        len(ids) if max_output_length is None else FIXED_OUTPUT_LENGTH for ids in completion_token_ids
      ],
    }

  def filter_length(examples) -> list[bool]:
    result = []
    for prompt_length, output_length in zip(examples['prompt_length'], examples['output_length']):
      if prompt_length < 4 or output_length < 4:
        result.append(False)
      elif prompt_length > 1024 or prompt_length + output_length > 2048:
        result.append(False)
      else:
        result.append(True)
    return result

  return (
    (
      dataset := load_dataset(
        'anon8231489123/ShareGPT_Vicuna_unfiltered',
        data_files='ShareGPT_V3_unfiltered_cleaned_split.json',
        split='train',
      )
    )
    .filter(lambda example: len(example['conversations']) >= 2, num_proc=8)
    .map(transform, remove_columns=dataset.column_names, batched=True)
    .map(
      process,
      fn_kwargs={'tokenizer': tokenizer, 'max_output_length': max_output_length},
      remove_columns=['human', 'gpt'],
      batched=True,
    )
    .filter(filter_length, batched=True)
    .shuffle(seed=42)
    .to_list()[:num_requests]
  )
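For reviewers who want to exercise the new sampling helper locally, a minimal sketch (the tokenizer checkpoint and the import path relative to openllm-python are assumptions for illustration, not part of this commit):

# Hypothetical usage of prepare_sharegpt_request; tokenizer choice is an assumption.
from transformers import AutoTokenizer

from tests._data import prepare_sharegpt_request

tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
requests = prepare_sharegpt_request(8, tokenizer)  # sample 8 prompts
for r in requests:
  print(r['prompt_length'], r['output_length'])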
@@ -1,60 +0,0 @@
from __future__ import annotations
import logging
import typing as t

from hypothesis import strategies as st

import openllm
from openllm_core._configuration import ModelSettings

logger = logging.getLogger(__name__)


@st.composite
def model_settings(draw: st.DrawFn):
  """Strategy for generating ModelSettings objects."""
  kwargs: dict[str, t.Any] = {
    'default_id': st.text(min_size=1),
    'model_ids': st.lists(st.text(), min_size=1),
    'architecture': st.text(min_size=1),
    'url': st.text(),
    'trust_remote_code': st.booleans(),
    'requirements': st.none() | st.lists(st.text(), min_size=1),
    'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
    'name_type': st.sampled_from(['dasherize', 'lowercase']),
    'timeout': st.integers(min_value=3600),
    'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
  }
  return draw(st.builds(ModelSettings, **kwargs))


def make_llm_config(
  cls_name: str,
  dunder_config: dict[str, t.Any] | ModelSettings,
  fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None,
  generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,
) -> type[openllm.LLMConfig]:
  globs: dict[str, t.Any] = {'openllm': openllm}
  _config_args: list[str] = []
  lines: list[str] = [f'class {cls_name}Config(openllm.LLMConfig):']
  for attr, value in dunder_config.items():
    _config_args.append(f'"{attr}": __attr_{attr}')
    globs[f'_{cls_name}Config__attr_{attr}'] = value
  lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
  if fields is not None:
    for field, type_, default in fields:
      lines.append(f' {field}: {type_} = openllm.LLMConfig.Field({default!r})')
  if generation_fields is not None:
    generation_lines = ['class GenerationConfig:']
    for field, default in generation_fields:
      generation_lines.append(f' {field} = {default!r}')
    lines.extend((' ' + line for line in generation_lines))

  script = '\n'.join(lines)

  if openllm.utils.DEBUG:
    logger.info('Generated class %s:\n%s', cls_name, script)

  eval(compile(script, 'name', 'exec'), globs)

  return globs[f'{cls_name}Config']
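make_llm_config builds a class body as source text and materialises it with compile/exec; a stripped-down sketch of the same trick (names are illustrative):

# Minimal form of the dynamic-class technique used above (illustrative).
globs: dict = {}
src = 'class DemoConfig:\n  field1: float = 3.0\n'
eval(compile(src, '<generated>', 'exec'), globs)
assert globs['DemoConfig'].field1 == 3.0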
@@ -1,152 +0,0 @@
from __future__ import annotations
import contextlib
import os
import typing as t
from unittest import mock

import attr
import pytest
from hypothesis import assume, given, strategies as st

import openllm
from openllm_core._configuration import GenerationConfig, ModelSettings, field_env_key

from ._strategies._configuration import make_llm_config, model_settings


def test_forbidden_access():
  cl_ = make_llm_config(
    'ForbiddenAccess',
    {
      'default_id': 'huggingface/t5-tiny-testing',
      'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'],
      'architecture': 'PreTrainedModel',
      'requirements': ['bentoml'],
    },
  )

  assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), '__config__')
  assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'GenerationConfig')
  assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'SamplingParams')
  assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)


@given(model_settings())
def test_class_normal_gen(gen_settings: ModelSettings):
  assume(gen_settings['default_id'] and all(i for i in gen_settings['model_ids']))
  cl_: type[openllm.LLMConfig] = make_llm_config('NotFullLLM', gen_settings)
  assert issubclass(cl_, openllm.LLMConfig)
  for key in gen_settings:
    assert object.__getattribute__(cl_, f'__openllm_{key}__') == gen_settings.__getitem__(key)


@given(model_settings(), st.integers())
def test_simple_struct_dump(gen_settings: ModelSettings, field1: int):
  cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
  assert cl_().model_dump()['field1'] == field1


@given(model_settings(), st.integers())
def test_config_derivation(gen_settings: ModelSettings, field1: int):
  cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
  new_cls = cl_.model_derivate('DerivedLLM', default_id='asdfasdf')
  assert new_cls.__openllm_default_id__ == 'asdfasdf'


@given(model_settings())
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
  cl_ = make_llm_config('AttrsProtocolLLM', gen_settings)
  assert attr.has(cl_)


@given(
  model_settings(),
  st.integers(max_value=283473),
  st.floats(min_value=0.0, max_value=1.0),
  st.integers(max_value=283473),
  st.floats(min_value=0.0, max_value=1.0),
)
def test_complex_struct_dump(
  gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float
):
  cl_ = make_llm_config(
    'ComplexLLM',
    gen_settings,
    fields=(('field1', 'float', field1),),
    generation_fields=(('temperature', temperature),),
  )
  sent = cl_()
  assert sent.model_dump()['field1'] == field1
  assert sent.model_dump()['generation_config']['temperature'] == temperature
  assert sent.model_dump(flatten=True)['field1'] == field1
  assert sent.model_dump(flatten=True)['temperature'] == temperature

  passed = cl_(field1=input_field1, temperature=input_temperature)
  assert passed.model_dump()['field1'] == input_field1
  assert passed.model_dump()['generation_config']['temperature'] == input_temperature
  assert passed.model_dump(flatten=True)['field1'] == input_field1
  assert passed.model_dump(flatten=True)['temperature'] == input_temperature

  pas_nested = cl_(generation_config={'temperature': input_temperature}, field1=input_field1)
  assert pas_nested.model_dump()['field1'] == input_field1
  assert pas_nested.model_dump()['generation_config']['temperature'] == input_temperature


@contextlib.contextmanager
def patch_env(**attrs: t.Any):
  with mock.patch.dict(os.environ, attrs, clear=True):
    yield


def test_struct_envvar():
  with patch_env(**{field_env_key('field1'): '4', field_env_key('temperature', suffix='generation'): '0.2'}):

    class EnvLLM(openllm.LLMConfig):
      __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
      field1: int = 2

      class GenerationConfig:
        temperature: float = 0.8

    sent = EnvLLM.model_construct_env()
    assert sent.field1 == 4
    assert sent['temperature'] == 0.2

    overwrite_default = EnvLLM()
    assert overwrite_default.field1 == 4
    assert overwrite_default['temperature'] == 0.2


def test_struct_provided_fields():
  class EnvLLM(openllm.LLMConfig):
    __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
    field1: int = 2

    class GenerationConfig:
      temperature: float = 0.8

  sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
  assert sent.field1 == 20
  assert sent.generation_config.temperature == 0.4


def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mk:
    mk.setenv(field_env_key('field1'), str(4.0))
    mk.setenv(field_env_key('temperature', suffix='generation'), str(0.2))
    sent = make_llm_config(
      'OverwriteWithEnvAvailable',
      {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'},
      fields=(('field1', 'float', 3.0),),
    ).model_construct_env(field1=20.0, temperature=0.4)
    assert sent.generation_config.temperature == 0.4
    assert sent.field1 == 20.0


@pytest.mark.parametrize('model_name', openllm.CONFIG_MAPPING.keys())
def test_configuration_dict_protocol(model_name: str):
  config = openllm.AutoConfig.for_model(model_name)
  assert isinstance(config.items(), list)
  assert isinstance(config.keys(), list)
  assert isinstance(config.values(), list)
  assert isinstance(dict(config), dict)
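The env-var tests lean on patch_env, which swaps os.environ wholesale for the duration of the block via mock.patch.dict; a quick sketch of the underlying behaviour (variable name is illustrative):

# Illustrative: clear=True replaces the whole environment, then restores it on exit.
import os
from unittest import mock

before = dict(os.environ)
with mock.patch.dict(os.environ, {'DEMO_VAR': '1'}, clear=True):
  assert list(os.environ) == ['DEMO_VAR']
assert os.environ == before  # original environment restored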
@@ -1,42 +1,16 @@
from __future__ import annotations
import itertools
import os
import typing as t

import pytest

import openllm

if t.TYPE_CHECKING:
  from openllm_core._typing_compat import LiteralBackend

_MODELING_MAPPING = {
  'flan_t5': 'google/flan-t5-small',
  'opt': 'facebook/opt-125m',
  'baichuan': 'baichuan-inc/Baichuan-7B',
}
_PROMPT_MAPPING = {
  'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'
}
import pytest, typing as t


def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.Any, t.Any]], None, None]:
  if model not in _MODELING_MAPPING:
    pytest.skip(f"'{model}' is not yet supported in framework testing.")
  backends: tuple[LiteralBackend, ...] = ('pt',)
  for backend, prompt in itertools.product(backends, _PROMPT_MAPPING.keys()):
    yield prompt, openllm.LLM(_MODELING_MAPPING[model], backend=backend)


def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
  if os.getenv('GITHUB_ACTIONS') is None:
    if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames:
      metafunc.parametrize(
        'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]
      )


def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
  # If no tests are collected, pytest exits with code 5, which makes the CI fail.
  if exitstatus == 5:
    session.exitstatus = 0
@pytest.fixture(
  scope='function',
  name='model_id',
  params={
    'meta-llama/Meta-Llama-3-8B-Instruct',
    'casperhansen/llama-3-70b-instruct-awq',
    'TheBloke/Nous-Hermes-2-Mixtral-8x7B-DPO-AWQ',
  },
)
def fixture_model_id(request) -> t.Generator[str, None, None]:
  yield request.param
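The new model_id fixture is consumed by name like any other parametrized fixture; a hypothetical consumer (not part of this commit):

# Hypothetical test showing how the parametrized model_id fixture is injected.
def test_model_id_points_at_a_hub_repo(model_id: str):
  assert model_id and '/' in model_id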
@@ -1,29 +0,0 @@
from __future__ import annotations
import os
import typing as t

import pytest

if t.TYPE_CHECKING:
  import openllm


@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
  assert llm.generate(prompt)

  assert llm.generate(prompt, temperature=0.8, top_p=0.23)


@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
  assert llm.generate(prompt)

  assert llm.generate(prompt, temperature=0.9, top_k=8)


@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
  assert llm.generate(prompt)

  assert llm.generate(prompt, temperature=0.95)
@@ -1,60 +0,0 @@
from __future__ import annotations
import functools
import os
import typing as t

import pytest

import openllm
from bentoml._internal.configuration.containers import BentoMLContainer

if t.TYPE_CHECKING:
  from pathlib import Path

HF_INTERNAL_T5_TESTING = 'hf-internal-testing/tiny-random-t5'

actions_xfail = functools.partial(
  pytest.mark.xfail,
  condition=os.getenv('GITHUB_ACTIONS') is not None,
  reason='Marked xfail on GitHub Actions due to flakiness and a non-isolated build environment.',
)


@actions_xfail
def test_general_build_with_internal_testing():
  bento_store = BentoMLContainer.bento_store.get()

  llm = openllm.LLM(model_id=HF_INTERNAL_T5_TESTING, serialisation='legacy')
  bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING)

  assert llm.llm_type == bento.info.labels['_type']
  assert llm.__llm_backend__ == bento.info.labels['_framework']

  bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING)
  assert len(bento_store.list(bento.tag)) == 1


@actions_xfail
def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
  local_path = tmp_path_factory.mktemp('local_t5')
  llm = openllm.LLM(model_id=HF_INTERNAL_T5_TESTING, serialisation='legacy')

  llm.model.save_pretrained(str(local_path))
  llm._tokenizer.save_pretrained(str(local_path))

  assert openllm.build('flan-t5', model_id=local_path.resolve().__fspath__(), model_version='local')


@pytest.fixture()
def dockerfile_template(tmp_path_factory: pytest.TempPathFactory):
  file = tmp_path_factory.mktemp('dockerfiles') / 'Dockerfile.template'
  file.write_text(
    "{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}"
  )
  return file


@pytest.mark.usefixtures('dockerfile_template')
@actions_xfail
def test_build_with_custom_dockerfile(dockerfile_template: Path):
  assert openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template))
56
openllm-python/tests/regression_test.py
Normal file
@@ -0,0 +1,56 @@
from __future__ import annotations

import pytest, subprocess, sys, openllm, bentoml, asyncio
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam

SERVER_PORT = 53822


@pytest.mark.asyncio
async def test_openai_compatible(model_id: str):
  server = subprocess.Popen([sys.executable, '-m', 'openllm', 'start', model_id, '--port', str(SERVER_PORT)])
  await asyncio.sleep(5)
  with bentoml.SyncHTTPClient(f'http://127.0.0.1:{SERVER_PORT}', server_ready_timeout=90) as client:
    assert client.is_ready(30)

  try:
    client = AsyncOpenAI(api_key='na', base_url=f'http://127.0.0.1:{SERVER_PORT}/v1')
    serve_model = (await client.models.list()).data[0].id
    assert serve_model == openllm.utils.normalise_model_name(model_id)
    streamable = await client.chat.completions.create(
      model=serve_model,
      max_tokens=512,
      stream=False,
      messages=[
        ChatCompletionSystemMessageParam(
          role='system', content='You will be the writing assistant that assumes the tone of Ernest Hemingway.'
        ),
        ChatCompletionUserMessageParam(
          role='user', content="Comment on why Camus thinks we should revolt against life's absurdity."
        ),
      ],
    )
    assert streamable is not None
  finally:
    server.terminate()


@pytest.mark.asyncio
async def test_generate_endpoint(model_id: str):
  server = subprocess.Popen([sys.executable, '-m', 'openllm', 'start', model_id, '--port', str(SERVER_PORT)])
  await asyncio.sleep(5)

  with bentoml.SyncHTTPClient(f'http://127.0.0.1:{SERVER_PORT}', server_ready_timeout=90) as client:
    assert client.is_ready(30)

  try:
    client = openllm.AsyncHTTPClient(f'http://127.0.0.1:{SERVER_PORT}', api_version='v1')
    assert await client.health()

    response = await client.generate(
      'Tell me more about Apple as a company', stop='technology', llm_config={'temperature': 0.5, 'top_p': 0.2}
    )
    assert response is not None
  finally:
    server.terminate()
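The same /v1 endpoint also supports streaming; a sketch of the streaming variant of the chat call above (message content is illustrative):

# Illustrative streaming variant; mirrors the non-streaming call in the test.
stream = await client.chat.completions.create(
  model=serve_model, max_tokens=64, stream=True,
  messages=[{'role': 'user', 'content': 'Say hi in one sentence.'}],
)
async for chunk in stream:
  print(chunk.choices[0].delta.content or '', end='')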
@@ -1,185 +0,0 @@
from __future__ import annotations
import os
import typing as t

import pytest

import bentoml
from openllm import _strategies as strategy
from openllm._strategies import CascadingResourceStrategy, NvidiaGpuResource, get_resource

if t.TYPE_CHECKING:
  from _pytest.monkeypatch import MonkeyPatch


def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv('CUDA_VISIBLE_DEVICES', '0,1')
    resource = NvidiaGpuResource.from_system()
    assert len(resource) == 2
    assert resource == ['0', '1']
    mcls.delenv('CUDA_VISIBLE_DEVICES')


def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv('CUDA_VISIBLE_DEVICES', '0,2,-1,1')
    resource = NvidiaGpuResource.from_system()
    assert len(resource) == 2
    assert resource == ['0', '2']
    mcls.delenv('CUDA_VISIBLE_DEVICES')


def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv('CUDA_VISIBLE_DEVICES', '-1')
    resource = NvidiaGpuResource.from_system()
    assert len(resource) == 0
    assert resource == []
    mcls.delenv('CUDA_VISIBLE_DEVICES')


def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43-ac33420d4628')
    resource = NvidiaGpuResource.from_system()
    assert len(resource) == 1
    assert resource == ['GPU-5ebe9f43-ac33420d4628']
    mcls.delenv('CUDA_VISIBLE_DEVICES')
  with monkeypatch.context() as mcls:
    mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43,GPU-ac33420d4628')
    resource = NvidiaGpuResource.from_system()
    assert len(resource) == 2
    assert resource == ['GPU-5ebe9f43', 'GPU-ac33420d4628']
    mcls.delenv('CUDA_VISIBLE_DEVICES')
  with monkeypatch.context() as mcls:
    mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43,-1,GPU-ac33420d4628')
    resource = NvidiaGpuResource.from_system()
    assert len(resource) == 1
    assert resource == ['GPU-5ebe9f43']
    mcls.delenv('CUDA_VISIBLE_DEVICES')
  with monkeypatch.context() as mcls:
    mcls.setenv('CUDA_VISIBLE_DEVICES', 'MIG-GPU-5ebe9f43-ac33420d4628')
    resource = NvidiaGpuResource.from_system()
    assert len(resource) == 1
    assert resource == ['MIG-GPU-5ebe9f43-ac33420d4628']
    mcls.delenv('CUDA_VISIBLE_DEVICES')


@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='skip GPUs test on CI')
def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    # make this test work on systems that have a GPU
    mcls.setenv('CUDA_VISIBLE_DEVICES', '')
    assert len(NvidiaGpuResource.from_system()) >= 0  # TODO: real from_system tests

    assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match(
      'Input list should be all string type.'
    )
    assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.')
    assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match(
      'Failed to parse available GPUs UUID'
    )


def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    # make this test work on systems that have a GPU
    mcls.setenv('CUDA_VISIBLE_DEVICES', '')
    assert NvidiaGpuResource.from_spec(1) == ['0']
    assert NvidiaGpuResource.from_spec('5') == ['0', '1', '2', '3', '4']
    assert NvidiaGpuResource.from_spec(1) == ['0']
    assert NvidiaGpuResource.from_spec(2) == ['0', '1']
    assert NvidiaGpuResource.from_spec('3') == ['0', '1', '2']
    assert NvidiaGpuResource.from_spec([1, 3]) == ['1', '3']
    assert NvidiaGpuResource.from_spec(['1', '3']) == ['1', '3']
    assert NvidiaGpuResource.from_spec(-1) == []
    assert NvidiaGpuResource.from_spec('-1') == []
    assert NvidiaGpuResource.from_spec('') == []
    assert NvidiaGpuResource.from_spec('-2') == []
    assert NvidiaGpuResource.from_spec('GPU-288347ab') == ['GPU-288347ab']
    assert NvidiaGpuResource.from_spec('GPU-288347ab,-1,GPU-ac33420d4628') == ['GPU-288347ab']
    assert NvidiaGpuResource.from_spec('GPU-288347ab,GPU-ac33420d4628') == ['GPU-288347ab', 'GPU-ac33420d4628']
    assert NvidiaGpuResource.from_spec('MIG-GPU-288347ab') == ['MIG-GPU-288347ab']

    with pytest.raises(TypeError):
      NvidiaGpuResource.from_spec((1, 2, 3))
    with pytest.raises(TypeError):
      NvidiaGpuResource.from_spec(1.5)
    with pytest.raises(ValueError):
      assert NvidiaGpuResource.from_spec(-2)


class GPURunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu')


def unvalidated_get_resource(x: dict[str, t.Any], y: str, validate: bool = False):
  return get_resource(x, y, validate=validate)


@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
  monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 1) == 1
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 1

  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 0.5) == 1
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 1
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 1
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 1


@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str):
  monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)

  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '0'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '1'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 1, 1)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '7'

  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 0)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '0'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 1)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '0'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 2)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '1'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 1)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '2'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 2)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '7'

  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 0.5, 0)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7'

  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 0)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 1)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '8,9'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.25, 0)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7,8,9'

  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 0)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '2,6'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 1)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '7,8'
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '9'


@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
def test_cascade_strategy_disabled_via_env(monkeypatch: MonkeyPatch, gpu_type: str):
  monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)

  monkeypatch.setenv('CUDA_VISIBLE_DEVICES', '')
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
  assert envs.get('CUDA_VISIBLE_DEVICES') == ''
  monkeypatch.delenv('CUDA_VISIBLE_DEVICES')

  monkeypatch.setenv('CUDA_VISIBLE_DEVICES', '-1')
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
  assert envs.get('CUDA_VISIBLE_DEVICES') == '-1'
  monkeypatch.delenv('CUDA_VISIBLE_DEVICES')
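The fractional workers_per_resource assertions follow a simple partitioning rule: with workers_per_resource = 1/k, worker i is assigned the contiguous slice of k GPUs starting at index i*k. A re-derivation of two of the asserted cases (helper name is illustrative, not from the library):

# Illustrative re-derivation of the CUDA_VISIBLE_DEVICES partitioning above.
def assign(gpus: list[int], workers_per_resource: float, worker_index: int) -> str:
  k = int(1 / workers_per_resource)  # GPUs handed to each worker
  return ','.join(map(str, gpus[worker_index * k:(worker_index + 1) * k]))

assert assign([2, 7, 8, 9], 0.5, 1) == '8,9'
assert assign([2, 6, 7, 8, 9], 0.4, 2) == '9'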
@@ -157,7 +157,7 @@ toplevel = ["openllm"]


[tool.pytest.ini_options]
addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"]
addopts = ["-rfEX", "-pno:warnings"]
python_files = ["test_*.py", "*_test.py"]
testpaths = ["openllm-python/tests"]