Mirror of https://github.com/bentoml/OpenLLM.git, synced 2026-05-06 23:02:43 -04:00
chore(style): enable yapf to match with style guidelines
Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
@@ -2,7 +2,6 @@ from __future__ import annotations
import os

from hypothesis import HealthCheck, settings

settings.register_profile("CI", settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None)

if "CI" in os.environ: settings.load_profile("CI")

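For context, hypothesis profiles let one test suite carry different settings per environment: settings.register_profile stores a named configuration and settings.load_profile activates it. A minimal standalone sketch of the same mechanism (the "dev" profile and its max_examples value are illustrative additions, not part of this commit):

import os

from hypothesis import HealthCheck, settings

# a CI profile that tolerates slow examples and disables the per-example deadline
settings.register_profile("ci", settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None)
# a hypothetical local profile that trades coverage for speed
settings.register_profile("dev", max_examples=10)
settings.load_profile("ci" if "CI" in os.environ else "dev")
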
@@ -2,22 +2,28 @@ from __future__ import annotations
import logging, typing as t, openllm
from openllm_core._configuration import ModelSettings
from hypothesis import strategies as st

logger = logging.getLogger(__name__)

env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_name in openllm.CONFIG_MAPPING.keys()])

@st.composite
def model_settings(draw: st.DrawFn):
  """Strategy for generating ModelSettings objects."""
  kwargs: dict[str, t.Any] = {
-    "default_id": st.text(min_size=1), "model_ids": st.lists(st.text(), min_size=1), "architecture": st.text(min_size=1), "url": st.text(), "requires_gpu": st.booleans(), "trust_remote_code": st.booleans(), "requirements": st.none()
-    | st.lists(st.text(), min_size=1), "default_implementation": st.dictionaries(st.sampled_from(["cpu", "nvidia.com/gpu"]), st.sampled_from(["vllm", "pt", "tf", "flax"])), "model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]), "runtime": st.sampled_from(["transformers", "ggml"]), "name_type": st.sampled_from(["dasherize", "lowercase"]), "timeout": st.integers(
-        min_value=3600
-    ), "workers_per_resource": st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
+    "default_id": st.text(min_size=1),
+    "model_ids": st.lists(st.text(), min_size=1),
+    "architecture": st.text(min_size=1),
+    "url": st.text(),
+    "requires_gpu": st.booleans(),
+    "trust_remote_code": st.booleans(),
+    "requirements": st.none() | st.lists(st.text(), min_size=1),
+    "default_implementation": st.dictionaries(st.sampled_from(["cpu", "nvidia.com/gpu"]), st.sampled_from(["vllm", "pt", "tf", "flax"])),
+    "model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]),
+    "runtime": st.sampled_from(["transformers", "ggml"]),
+    "name_type": st.sampled_from(["dasherize", "lowercase"]),
+    "timeout": st.integers(min_value=3600),
+    "workers_per_resource": st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
  }
  return draw(st.builds(ModelSettings, **kwargs))

def make_llm_config(cls_name: str, dunder_config: dict[str, t.Any] | ModelSettings, fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None, generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,) -> type[openllm.LLMConfig]:
  globs: dict[str, t.Any] = {"openllm": openllm}
  _config_args: list[str] = []

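For readers unfamiliar with hypothesis, draw() inside an @st.composite function samples a value from an inner strategy, which is how model_settings() assembles a complete ModelSettings from the per-field strategies above. A self-contained sketch of the same pattern, using a hypothetical Point type purely for illustration:

import typing as t

from hypothesis import given, strategies as st

class Point(t.NamedTuple):
  x: int
  y: int

@st.composite
def points(draw: st.DrawFn) -> Point:
  # each draw() call pulls one example from the given strategy
  return Point(x=draw(st.integers()), y=draw(st.integers(min_value=0)))

@given(points())
def test_point_y_is_non_negative(p: Point) -> None:
  assert p.y >= 0
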
@@ -4,7 +4,6 @@ from unittest import mock
from openllm_core._configuration import GenerationConfig, ModelSettings, field_env_key
from hypothesis import assume, given, strategies as st
from ._strategies._configuration import make_llm_config, model_settings

# XXX: @aarnphm fixes TypedDict behaviour in 3.11
@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason="TypedDict in 3.11 behaves differently, so we need to fix this")
def test_missing_default():
@@ -14,7 +13,6 @@ def test_missing_default():
  make_llm_config("MissingModelId", {"default_id": "huggingface/t5-tiny-testing", "requirements": ["bentoml"]})
  with pytest.raises(ValueError, match="Missing required fields *"):
    make_llm_config("MissingArchitecture", {"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing"], "requirements": ["bentoml"],},)

def test_forbidden_access():
  cl_ = make_llm_config("ForbiddenAccess", {"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing", "bentoml/t5-tiny-testing"], "architecture": "PreTrainedModel", "requirements": ["bentoml"],},)

@@ -22,7 +20,6 @@ def test_forbidden_access():
  assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "GenerationConfig",)
  assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "SamplingParams",)
  assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)

@given(model_settings())
def test_class_normal_gen(gen_settings: ModelSettings):
  assume(gen_settings["default_id"] and all(i for i in gen_settings["model_ids"]))
@@ -30,23 +27,19 @@ def test_class_normal_gen(gen_settings: ModelSettings):
  assert issubclass(cl_, openllm.LLMConfig)
  for key in gen_settings:
    assert object.__getattribute__(cl_, f"__openllm_{key}__") == gen_settings.__getitem__(key)

@given(model_settings(), st.integers())
def test_simple_struct_dump(gen_settings: ModelSettings, field1: int):
  cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
  assert cl_().model_dump()["field1"] == field1

@given(model_settings(), st.integers())
def test_config_derivation(gen_settings: ModelSettings, field1: int):
  cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
  new_cls = cl_.model_derivate("DerivedLLM", default_id="asdfasdf")
  assert new_cls.__openllm_default_id__ == "asdfasdf"

@given(model_settings())
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
  cl_ = make_llm_config("AttrsProtocolLLM", gen_settings)
  assert attr.has(cl_)

@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),)
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
  cl_ = make_llm_config("ComplexLLM", gen_settings, fields=(("field1", "float", field1),), generation_fields=(("temperature", temperature),),)
@@ -65,12 +58,10 @@ def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperatu
  pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1)
  assert pas_nested.model_dump()["field1"] == input_field1
  assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature

@contextlib.contextmanager
def patch_env(**attrs: t.Any):
  with mock.patch.dict(os.environ, attrs, clear=True):
    yield

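A usage sketch for patch_env: because mock.patch.dict is invoked with clear=True, the patched environment contains only the supplied keys, so a test can assert both presence and absence. The FOO key is illustrative:

def test_patch_env_isolates_environ():
  with patch_env(FOO="1"):
    assert os.environ["FOO"] == "1"  # only the supplied keys are visible
    assert "PATH" not in os.environ  # clear=True wipes everything else
  assert "FOO" not in os.environ     # the original environment is restored on exit
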
def test_struct_envvar():
  with patch_env(**{field_env_key("env_llm", "field1"): "4", field_env_key("env_llm", "temperature", suffix="generation"): "0.2",}):

@@ -88,7 +79,6 @@ def test_struct_envvar():
    overwrite_default = EnvLLM()
    assert overwrite_default.field1 == 4
    assert overwrite_default["temperature"] == 0.2

def test_struct_provided_fields():
  class EnvLLM(openllm.LLMConfig):
    __config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel",}
@@ -100,7 +90,6 @@ def test_struct_provided_fields():
  sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
  assert sent.field1 == 20
  assert sent.generation_config.temperature == 0.4

def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mk:
    mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0))
@@ -108,13 +97,11 @@ def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPat
    sent = make_llm_config("OverwriteWithEnvAvailable", {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel"}, fields=(("field1", "float", 3.0),),).model_construct_env(field1=20.0, temperature=0.4)
    assert sent.generation_config.temperature == 0.4
    assert sent.field1 == 20.0

@given(model_settings())
@pytest.mark.parametrize(("return_dict", "typ"), [(True, dict), (False, transformers.GenerationConfig)])
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
  cl_ = make_llm_config("ConversionLLM", gen_settings)
  assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)

@given(model_settings())
def test_click_conversion(gen_settings: ModelSettings):
  # currently our conversion omits Union types.
@@ -126,7 +113,6 @@ def test_click_conversion(gen_settings: ModelSettings):
  filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
  click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")]
  assert len(filtered) == len(click_options_filtered)

@pytest.mark.parametrize("model_name", openllm.CONFIG_MAPPING.keys())
def test_configuration_dict_protocol(model_name: str):
  config = openllm.AutoConfig.for_model(model_name)

@@ -13,7 +13,6 @@ def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunn
  for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()):
    llm = openllm.Runner(model, model_id=_FRAMEWORK_MAPPING[model], ensure_available=True, implementation=framework, init_local=True,)
    yield prompt, llm

def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
  if os.getenv("GITHUB_ACTIONS") is None:
    if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames:

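The pytest_generate_tests hook is truncated in this hunk; a hedged sketch of how such dynamic parametrization is typically completed with pytest's stock metafunc API (the model tuple here is an assumption, not the commit's actual list):

def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
  # only parametrize outside GitHub Actions, mirroring the guard above
  if os.getenv("GITHUB_ACTIONS") is None:
    if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames:
      # collect (prompt, llm) pairs from the generator defined earlier in this file
      params = [pair for model in ("flan_t5", "opt") for pair in parametrise_local_llm(model)]
      metafunc.parametrize("prompt,llm", params)
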
@@ -5,7 +5,6 @@ import attr, docker, docker.errors, docker.types, orjson, pytest, openllm
from syrupy.extensions.json import JSONSnapshotExtension
from openllm._llm import normalise_model_name
from openllm_core._typing_compat import DictStrAny, ListAny

logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
@@ -14,7 +13,6 @@ if t.TYPE_CHECKING:
  from syrupy.types import PropertyFilter, PropertyMatcher, SerializableData, SerializedData
  from openllm._configuration import GenerationConfig
  from openllm.client import BaseAsyncClient

class ResponseComparator(JSONSnapshotExtension):
  def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
    if openllm.utils.LazyType(ListAny).isinstance(data):
@@ -52,11 +50,9 @@ class ResponseComparator(JSONSnapshotExtension):
      return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))

    return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])

@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
  return snapshot.use_extension(ResponseComparator)

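Usage sketch for the fixture above: syrupy routes the assertion through the custom extension, so a test simply compares a response against the injected snapshot value and ResponseComparator's custom serialization and comparison logic decides equality. The generate() call and prompt are illustrative, not the suite's actual requests:

@pytest.mark.asyncio()
async def test_generation_matches_snapshot(flan_t5, response_snapshot):
  client = await flan_t5
  # the == triggers the snapshot comparison implemented by ResponseComparator
  assert (await client.generate("What is Deep Learning?")) == response_snapshot
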
@attr.define(init=False)
class _Handle(ABC):
  port: int
@@ -88,7 +84,6 @@ class _Handle(ABC):
    except Exception:
      time.sleep(1)
    raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")

@attr.define(init=False)
class LocalHandle(_Handle):
  process: subprocess.Popen[bytes]
@@ -98,12 +93,10 @@ class LocalHandle(_Handle):

  def status(self) -> bool:
    return self.process.poll() is None

class HandleProtocol(t.Protocol):
  @contextlib.contextmanager
  def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
    ...

@attr.define(init=False)
class DockerHandle(_Handle):
  container_name: str
@@ -115,7 +108,6 @@ class DockerHandle(_Handle):
  def status(self) -> bool:
    container = self.docker_client.containers.get(self.container_name)
    return container.status in ["running", "created"]

@contextlib.contextmanager
def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,):
  with openllm.utils.reserve_free_port() as port:
@@ -136,7 +128,6 @@ def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.
    proc.stdout.close()
    if proc.stderr:
      proc.stderr.close()

@contextlib.contextmanager
def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,):
  envvar = openllm.utils.EnvVarMixin(model)
@@ -177,23 +168,19 @@ def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode
  print(container_output, file=sys.stderr)

  container.remove()

@pytest.fixture(scope="session", autouse=True)
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
  stack = contextlib.ExitStack()
  yield stack
  stack.close()

@pytest.fixture(scope="module")
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
  loop = asyncio.get_event_loop()
  yield loop
  loop.close()

@pytest.fixture(params=["container", "local"], scope="session")
def deployment_mode(request: pytest.FixtureRequest) -> str:
  return request.param

@pytest.fixture(scope="module")
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
  if deployment_mode == "container":

@@ -9,21 +9,17 @@ if t.TYPE_CHECKING:
  import contextlib

  from .conftest import HandleProtocol, ResponseComparator, _Handle

model = "flan_t5"
model_id = "google/flan-t5-small"

@pytest.fixture(scope="module")
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
  with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
    with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
      yield handle

@pytest.fixture(scope="module")
async def flan_t5(flan_t5_handle: _Handle):
  await flan_t5_handle.health(240)
  return flan_t5_handle.client

@pytest.mark.asyncio()
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
  client = await flan_t5

@@ -9,21 +9,17 @@ if t.TYPE_CHECKING:
  import contextlib

  from .conftest import HandleProtocol, ResponseComparator, _Handle

model = "opt"
model_id = "facebook/opt-125m"

@pytest.fixture(scope="module")
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
  with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
    with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
      yield handle

@pytest.fixture(scope="module")
async def opt_125m(opt_125m_handle: _Handle):
  await opt_125m_handle.health(240)
  return opt_125m_handle.client

@pytest.mark.asyncio()
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
  client = await opt_125m

@@ -2,19 +2,16 @@ from __future__ import annotations
import os, typing as t, pytest

if t.TYPE_CHECKING: import openllm

@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
  assert llm(prompt)

  assert llm(prompt, temperature=0.8, top_p=0.23)

@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
  assert llm(prompt)

  assert llm(prompt, temperature=0.9, top_k=8)

@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
  assert llm(prompt)

@@ -6,7 +6,6 @@ if t.TYPE_CHECKING: from pathlib import Path
HF_INTERNAL_T5_TESTING = "hf-internal-testing/tiny-random-t5"

actions_xfail = functools.partial(pytest.mark.xfail, condition=os.getenv("GITHUB_ACTIONS") is not None, reason="Marking GitHub Actions to xfail due to flakiness and building environment not isolated.",)

@actions_xfail
def test_general_build_with_internal_testing():
  bento_store = BentoMLContainer.bento_store.get()
@@ -19,7 +18,6 @@ def test_general_build_with_internal_testing():

  bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
  assert len(bento_store.list(bento.tag)) == 1

@actions_xfail
def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
  local_path = tmp_path_factory.mktemp("local_t5")
@@ -31,13 +29,11 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
  llm.save_pretrained(local_path)

  assert openllm.build("flan-t5", model_id=local_path.resolve().__fspath__(), model_version="local")

@pytest.fixture()
def dockerfile_template(tmp_path_factory: pytest.TempPathFactory):
  file = tmp_path_factory.mktemp("dockerfiles") / "Dockerfile.template"
  file.write_text("{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}")
  return file

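For readability, the one-line template the fixture writes expands to the following Jinja document; it extends BentoML's generated Dockerfile and appends a single RUN layer after the default entrypoint setup:

{% extends bento_base_template %}
{% block SETUP_BENTO_ENTRYPOINT %}
{{ super() }}
RUN echo 'sanity from custom dockerfile'
{% endblock %}
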
@pytest.mark.usefixtures("dockerfile_template")
@actions_xfail
def test_build_with_custom_dockerfile(dockerfile_template: Path):

@@ -3,7 +3,6 @@ import os, typing as t, pytest, bentoml
from openllm_core import _strategies as strategy
from openllm_core._strategies import CascadingResourceStrategy, NvidiaGpuResource, get_resource
if t.TYPE_CHECKING: from _pytest.monkeypatch import MonkeyPatch

def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv("CUDA_VISIBLE_DEVICES", "0,1")
@@ -11,7 +10,6 @@ def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch):
    assert len(resource) == 2
    assert resource == ["0", "1"]
    mcls.delenv("CUDA_VISIBLE_DEVICES")

def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv("CUDA_VISIBLE_DEVICES", "0,2,-1,1")
@@ -19,7 +17,6 @@ def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch):
    assert len(resource) == 2
    assert resource == ["0", "2"]
    mcls.delenv("CUDA_VISIBLE_DEVICES")

def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv("CUDA_VISIBLE_DEVICES", "-1")
@@ -27,7 +24,6 @@ def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch):
    assert len(resource) == 0
    assert resource == []
    mcls.delenv("CUDA_VISIBLE_DEVICES")

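The tests above encode the CUDA_VISIBLE_DEVICES contract: device entries are honoured in order until the first negative entry, which hides everything after it. A minimal sketch of that parsing rule, written as an illustration of the expected behaviour rather than as NvidiaGpuResource's actual implementation:

def parse_visible_devices(spec: str) -> list[str]:
  devices: list[str] = []
  for item in spec.split(","):
    item = item.strip()
    # a negative index such as -1 masks every remaining device
    if item.lstrip("-").isdigit() and int(item) < 0:
      break
    devices.append(item)
  return devices

assert parse_visible_devices("0,1") == ["0", "1"]
assert parse_visible_devices("0,2,-1,1") == ["0", "2"]
assert parse_visible_devices("-1") == []
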
def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43-ac33420d4628")
@@ -53,7 +49,6 @@ def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch):
    assert len(resource) == 1
    assert resource == ["MIG-GPU-5ebe9f43-ac33420d4628"]
    mcls.delenv("CUDA_VISIBLE_DEVICES")

@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="skip GPUs test on CI")
def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
@@ -64,7 +59,6 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
    assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match("Input list should be all string type.")
    assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match("Input list should be all string type.")
    assert pytest.raises(ValueError, NvidiaGpuResource.validate, ["GPU-5ebe9f43", "GPU-ac33420d4628"]).match("Failed to parse available GPUs UUID")

def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
  with monkeypatch.context() as mcls:
    # to make this test work on systems that have a GPU
@@ -91,13 +85,10 @@ def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
      NvidiaGpuResource.from_spec(1.5)
    with pytest.raises(ValueError):
      assert NvidiaGpuResource.from_spec(-2)

class GPURunnable(bentoml.Runnable):
  SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu")

def unvalidated_get_resource(x: dict[str, t.Any], y: str, validate: bool = False):
  return get_resource(x, y, validate=validate)

@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
  monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
@@ -108,7 +99,6 @@ def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 1
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 1
  assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 1

@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str):
  monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
@@ -147,7 +137,6 @@ def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str):
  assert envs.get("CUDA_VISIBLE_DEVICES") == "7,8"
  envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2)
  assert envs.get("CUDA_VISIBLE_DEVICES") == "9"

@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
def test_cascade_strategy_disabled_via_env(monkeypatch: MonkeyPatch, gpu_type: str):
  monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
