mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-04-22 16:07:24 -04:00
chore(style): synchronized style across packages [skip ci]
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,6 @@ from __future__ import annotations
|
||||
import os
|
||||
|
||||
from hypothesis import HealthCheck, settings
|
||||
settings.register_profile("CI", settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None)
|
||||
settings.register_profile('CI', settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None)
|
||||
|
||||
if "CI" in os.environ: settings.load_profile("CI")
|
||||
if 'CI' in os.environ: settings.load_profile('CI')
|
||||
|
||||
@@ -7,21 +7,21 @@ logger = logging.getLogger(__name__)
|
||||
env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_name in openllm.CONFIG_MAPPING.keys()])
|
||||
@st.composite
|
||||
def model_settings(draw: st.DrawFn):
|
||||
"""Strategy for generating ModelSettings objects."""
|
||||
'''Strategy for generating ModelSettings objects.'''
|
||||
kwargs: dict[str, t.Any] = {
|
||||
"default_id": st.text(min_size=1),
|
||||
"model_ids": st.lists(st.text(), min_size=1),
|
||||
"architecture": st.text(min_size=1),
|
||||
"url": st.text(),
|
||||
"requires_gpu": st.booleans(),
|
||||
"trust_remote_code": st.booleans(),
|
||||
"requirements": st.none() | st.lists(st.text(), min_size=1),
|
||||
"default_implementation": st.dictionaries(st.sampled_from(["cpu", "nvidia.com/gpu"]), st.sampled_from(["vllm", "pt", "tf", "flax"])),
|
||||
"model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]),
|
||||
"runtime": st.sampled_from(["transformers", "ggml"]),
|
||||
"name_type": st.sampled_from(["dasherize", "lowercase"]),
|
||||
"timeout": st.integers(min_value=3600),
|
||||
"workers_per_resource": st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
|
||||
'default_id': st.text(min_size=1),
|
||||
'model_ids': st.lists(st.text(), min_size=1),
|
||||
'architecture': st.text(min_size=1),
|
||||
'url': st.text(),
|
||||
'requires_gpu': st.booleans(),
|
||||
'trust_remote_code': st.booleans(),
|
||||
'requirements': st.none() | st.lists(st.text(), min_size=1),
|
||||
'default_implementation': st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])),
|
||||
'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
|
||||
'runtime': st.sampled_from(['transformers', 'ggml']),
|
||||
'name_type': st.sampled_from(['dasherize', 'lowercase']),
|
||||
'timeout': st.integers(min_value=3600),
|
||||
'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
|
||||
}
|
||||
return draw(st.builds(ModelSettings, **kwargs))
|
||||
def make_llm_config(
|
||||
@@ -30,27 +30,27 @@ def make_llm_config(
|
||||
fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None,
|
||||
generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,
|
||||
) -> type[openllm.LLMConfig]:
|
||||
globs: dict[str, t.Any] = {"openllm": openllm}
|
||||
globs: dict[str, t.Any] = {'openllm': openllm}
|
||||
_config_args: list[str] = []
|
||||
lines: list[str] = [f"class {cls_name}Config(openllm.LLMConfig):"]
|
||||
lines: list[str] = [f'class {cls_name}Config(openllm.LLMConfig):']
|
||||
for attr, value in dunder_config.items():
|
||||
_config_args.append(f'"{attr}": __attr_{attr}')
|
||||
globs[f"_{cls_name}Config__attr_{attr}"] = value
|
||||
globs[f'_{cls_name}Config__attr_{attr}'] = value
|
||||
lines.append(f' __config__ = {{ {", ".join(_config_args)} }}')
|
||||
if fields is not None:
|
||||
for field, type_, default in fields:
|
||||
lines.append(f" {field}: {type_} = openllm.LLMConfig.Field({default!r})")
|
||||
lines.append(f' {field}: {type_} = openllm.LLMConfig.Field({default!r})')
|
||||
if generation_fields is not None:
|
||||
generation_lines = ["class GenerationConfig:"]
|
||||
generation_lines = ['class GenerationConfig:']
|
||||
for field, default in generation_fields:
|
||||
generation_lines.append(f" {field} = {default!r}")
|
||||
lines.extend((" " + line for line in generation_lines))
|
||||
generation_lines.append(f' {field} = {default!r}')
|
||||
lines.extend((' ' + line for line in generation_lines))
|
||||
|
||||
script = "\n".join(lines)
|
||||
script = '\n'.join(lines)
|
||||
|
||||
if openllm.utils.DEBUG:
|
||||
logger.info("Generated class %s:\n%s", cls_name, script)
|
||||
logger.info('Generated class %s:\n%s', cls_name, script)
|
||||
|
||||
eval(compile(script, "name", "exec"), globs)
|
||||
eval(compile(script, 'name', 'exec'), globs)
|
||||
|
||||
return globs[f"{cls_name}Config"]
|
||||
return globs[f'{cls_name}Config']
|
||||
|
||||
@@ -5,72 +5,72 @@ from openllm_core._configuration import GenerationConfig, ModelSettings, field_e
|
||||
from hypothesis import assume, given, strategies as st
|
||||
from ._strategies._configuration import make_llm_config, model_settings
|
||||
# XXX: @aarnphm fixes TypedDict behaviour in 3.11
|
||||
@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason="TypedDict in 3.11 behaves differently, so we need to fix this")
|
||||
@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason='TypedDict in 3.11 behaves differently, so we need to fix this')
|
||||
def test_missing_default():
|
||||
with pytest.raises(ValueError, match="Missing required fields *"):
|
||||
make_llm_config("MissingDefaultId", {"name_type": "lowercase", "requirements": ["bentoml"]})
|
||||
with pytest.raises(ValueError, match="Missing required fields *"):
|
||||
make_llm_config("MissingModelId", {"default_id": "huggingface/t5-tiny-testing", "requirements": ["bentoml"]})
|
||||
with pytest.raises(ValueError, match="Missing required fields *"):
|
||||
make_llm_config("MissingArchitecture", {"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing"], "requirements": ["bentoml"],},)
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingDefaultId', {'name_type': 'lowercase', 'requirements': ['bentoml']})
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingModelId', {'default_id': 'huggingface/t5-tiny-testing', 'requirements': ['bentoml']})
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingArchitecture', {'default_id': 'huggingface/t5-tiny-testing', 'model_ids': ['huggingface/t5-tiny-testing'], 'requirements': ['bentoml'],},)
|
||||
def test_forbidden_access():
|
||||
cl_ = make_llm_config(
|
||||
"ForbiddenAccess", {
|
||||
"default_id": "huggingface/t5-tiny-testing", "model_ids": ["huggingface/t5-tiny-testing", "bentoml/t5-tiny-testing"], "architecture": "PreTrainedModel", "requirements": ["bentoml"],
|
||||
'ForbiddenAccess', {
|
||||
'default_id': 'huggingface/t5-tiny-testing', 'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'], 'architecture': 'PreTrainedModel', 'requirements': ['bentoml'],
|
||||
},
|
||||
)
|
||||
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "__config__",)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "GenerationConfig",)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), "SamplingParams",)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), '__config__',)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'GenerationConfig',)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'SamplingParams',)
|
||||
assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)
|
||||
@given(model_settings())
|
||||
def test_class_normal_gen(gen_settings: ModelSettings):
|
||||
assume(gen_settings["default_id"] and all(i for i in gen_settings["model_ids"]))
|
||||
cl_: type[openllm.LLMConfig] = make_llm_config("NotFullLLM", gen_settings)
|
||||
assume(gen_settings['default_id'] and all(i for i in gen_settings['model_ids']))
|
||||
cl_: type[openllm.LLMConfig] = make_llm_config('NotFullLLM', gen_settings)
|
||||
assert issubclass(cl_, openllm.LLMConfig)
|
||||
for key in gen_settings:
|
||||
assert object.__getattribute__(cl_, f"__openllm_{key}__") == gen_settings.__getitem__(key)
|
||||
assert object.__getattribute__(cl_, f'__openllm_{key}__') == gen_settings.__getitem__(key)
|
||||
@given(model_settings(), st.integers())
|
||||
def test_simple_struct_dump(gen_settings: ModelSettings, field1: int):
|
||||
cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
|
||||
assert cl_().model_dump()["field1"] == field1
|
||||
cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
|
||||
assert cl_().model_dump()['field1'] == field1
|
||||
@given(model_settings(), st.integers())
|
||||
def test_config_derivation(gen_settings: ModelSettings, field1: int):
|
||||
cl_ = make_llm_config("IdempotentLLM", gen_settings, fields=(("field1", "float", field1),))
|
||||
new_cls = cl_.model_derivate("DerivedLLM", default_id="asdfasdf")
|
||||
assert new_cls.__openllm_default_id__ == "asdfasdf"
|
||||
cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
|
||||
new_cls = cl_.model_derivate('DerivedLLM', default_id='asdfasdf')
|
||||
assert new_cls.__openllm_default_id__ == 'asdfasdf'
|
||||
@given(model_settings())
|
||||
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config("AttrsProtocolLLM", gen_settings)
|
||||
cl_ = make_llm_config('AttrsProtocolLLM', gen_settings)
|
||||
assert attr.has(cl_)
|
||||
@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),)
|
||||
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
|
||||
cl_ = make_llm_config("ComplexLLM", gen_settings, fields=(("field1", "float", field1),), generation_fields=(("temperature", temperature),),)
|
||||
cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),),)
|
||||
sent = cl_()
|
||||
assert sent.model_dump()["field1"] == field1
|
||||
assert sent.model_dump()["generation_config"]["temperature"] == temperature
|
||||
assert sent.model_dump(flatten=True)["field1"] == field1
|
||||
assert sent.model_dump(flatten=True)["temperature"] == temperature
|
||||
assert sent.model_dump()['field1'] == field1
|
||||
assert sent.model_dump()['generation_config']['temperature'] == temperature
|
||||
assert sent.model_dump(flatten=True)['field1'] == field1
|
||||
assert sent.model_dump(flatten=True)['temperature'] == temperature
|
||||
|
||||
passed = cl_(field1=input_field1, temperature=input_temperature)
|
||||
assert passed.model_dump()["field1"] == input_field1
|
||||
assert passed.model_dump()["generation_config"]["temperature"] == input_temperature
|
||||
assert passed.model_dump(flatten=True)["field1"] == input_field1
|
||||
assert passed.model_dump(flatten=True)["temperature"] == input_temperature
|
||||
assert passed.model_dump()['field1'] == input_field1
|
||||
assert passed.model_dump()['generation_config']['temperature'] == input_temperature
|
||||
assert passed.model_dump(flatten=True)['field1'] == input_field1
|
||||
assert passed.model_dump(flatten=True)['temperature'] == input_temperature
|
||||
|
||||
pas_nested = cl_(generation_config={"temperature": input_temperature}, field1=input_field1)
|
||||
assert pas_nested.model_dump()["field1"] == input_field1
|
||||
assert pas_nested.model_dump()["generation_config"]["temperature"] == input_temperature
|
||||
pas_nested = cl_(generation_config={'temperature': input_temperature}, field1=input_field1)
|
||||
assert pas_nested.model_dump()['field1'] == input_field1
|
||||
assert pas_nested.model_dump()['generation_config']['temperature'] == input_temperature
|
||||
@contextlib.contextmanager
|
||||
def patch_env(**attrs: t.Any):
|
||||
with mock.patch.dict(os.environ, attrs, clear=True):
|
||||
yield
|
||||
def test_struct_envvar():
|
||||
with patch_env(**{field_env_key("env_llm", "field1"): "4", field_env_key("env_llm", "temperature", suffix="generation"): "0.2",}):
|
||||
with patch_env(**{field_env_key('env_llm', 'field1'): '4', field_env_key('env_llm', 'temperature', suffix='generation'): '0.2',}):
|
||||
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel",}
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
|
||||
field1: int = 2
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -78,14 +78,14 @@ def test_struct_envvar():
|
||||
|
||||
sent = EnvLLM.model_construct_env()
|
||||
assert sent.field1 == 4
|
||||
assert sent["temperature"] == 0.2
|
||||
assert sent['temperature'] == 0.2
|
||||
|
||||
overwrite_default = EnvLLM()
|
||||
assert overwrite_default.field1 == 4
|
||||
assert overwrite_default["temperature"] == 0.2
|
||||
assert overwrite_default['temperature'] == 0.2
|
||||
def test_struct_provided_fields():
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel",}
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
|
||||
field1: int = 2
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -96,19 +96,19 @@ def test_struct_provided_fields():
|
||||
assert sent.generation_config.temperature == 0.4
|
||||
def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mk:
|
||||
mk.setenv(field_env_key("overwrite_with_env_available", "field1"), str(4.0))
|
||||
mk.setenv(field_env_key("overwrite_with_env_available", "temperature", suffix="generation"), str(0.2))
|
||||
mk.setenv(field_env_key('overwrite_with_env_available', 'field1'), str(4.0))
|
||||
mk.setenv(field_env_key('overwrite_with_env_available', 'temperature', suffix='generation'), str(0.2))
|
||||
sent = make_llm_config(
|
||||
"OverwriteWithEnvAvailable", {
|
||||
"default_id": "asdfasdf", "model_ids": ["asdf", "asdfasdfads"], "architecture": "PreTrainedModel"
|
||||
}, fields=(("field1", "float", 3.0),),
|
||||
'OverwriteWithEnvAvailable', {
|
||||
'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'
|
||||
}, fields=(('field1', 'float', 3.0),),
|
||||
).model_construct_env(field1=20.0, temperature=0.4)
|
||||
assert sent.generation_config.temperature == 0.4
|
||||
assert sent.field1 == 20.0
|
||||
@given(model_settings())
|
||||
@pytest.mark.parametrize(("return_dict", "typ"), [(True, dict), (False, transformers.GenerationConfig)])
|
||||
@pytest.mark.parametrize(('return_dict', 'typ'), [(True, dict), (False, transformers.GenerationConfig)])
|
||||
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config("ConversionLLM", gen_settings)
|
||||
cl_ = make_llm_config('ConversionLLM', gen_settings)
|
||||
assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)
|
||||
@given(model_settings())
|
||||
def test_click_conversion(gen_settings: ModelSettings):
|
||||
@@ -116,12 +116,12 @@ def test_click_conversion(gen_settings: ModelSettings):
|
||||
def cli_mock(**attrs: t.Any):
|
||||
return attrs
|
||||
|
||||
cl_ = make_llm_config("ClickConversionLLM", gen_settings)
|
||||
cl_ = make_llm_config('ClickConversionLLM', gen_settings)
|
||||
wrapped = cl_.to_click_options(cli_mock)
|
||||
filtered = {k for k, v in cl_.__openllm_hints__.items() if t.get_origin(v) is not t.Union}
|
||||
click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith("fake_")]
|
||||
click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith('fake_')]
|
||||
assert len(filtered) == len(click_options_filtered)
|
||||
@pytest.mark.parametrize("model_name", openllm.CONFIG_MAPPING.keys())
|
||||
@pytest.mark.parametrize('model_name', openllm.CONFIG_MAPPING.keys())
|
||||
def test_configuration_dict_protocol(model_name: str):
|
||||
config = openllm.AutoConfig.for_model(model_name)
|
||||
assert isinstance(config.items(), list)
|
||||
|
||||
@@ -2,21 +2,21 @@ from __future__ import annotations
|
||||
import itertools, os, typing as t, pytest, openllm
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralRuntime
|
||||
|
||||
_FRAMEWORK_MAPPING = {"flan_t5": "google/flan-t5-small", "opt": "facebook/opt-125m", "baichuan": "baichuan-inc/Baichuan-7B",}
|
||||
_PROMPT_MAPPING = {"qa": "Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?",}
|
||||
_FRAMEWORK_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B',}
|
||||
_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?',}
|
||||
def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
|
||||
if model not in _FRAMEWORK_MAPPING: pytest.skip(f"'{model}' is not yet supported in framework testing.")
|
||||
runtime_impl: tuple[LiteralRuntime, ...] = tuple()
|
||||
if model in openllm.MODEL_MAPPING_NAMES: runtime_impl += ("pt",)
|
||||
if model in openllm.MODEL_FLAX_MAPPING_NAMES: runtime_impl += ("flax",)
|
||||
if model in openllm.MODEL_TF_MAPPING_NAMES: runtime_impl += ("tf",)
|
||||
if model in openllm.MODEL_MAPPING_NAMES: runtime_impl += ('pt',)
|
||||
if model in openllm.MODEL_FLAX_MAPPING_NAMES: runtime_impl += ('flax',)
|
||||
if model in openllm.MODEL_TF_MAPPING_NAMES: runtime_impl += ('tf',)
|
||||
for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()):
|
||||
llm = openllm.Runner(model, model_id=_FRAMEWORK_MAPPING[model], ensure_available=True, implementation=framework, init_local=True,)
|
||||
yield prompt, llm
|
||||
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
|
||||
if os.getenv("GITHUB_ACTIONS") is None:
|
||||
if "prompt" in metafunc.fixturenames and "llm" in metafunc.fixturenames:
|
||||
metafunc.parametrize("prompt,llm", [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
|
||||
if os.getenv('GITHUB_ACTIONS') is None:
|
||||
if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames:
|
||||
metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
|
||||
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
|
||||
# If no tests are collected, pytest exists with code 5, which makes the CI fail.
|
||||
if exitstatus == 5: session.exitstatus = 0
|
||||
|
||||
@@ -27,13 +27,13 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
try:
|
||||
data = orjson.loads(data)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise ValueError(f"Failed to decode JSON data: {data}") from err
|
||||
raise ValueError(f'Failed to decode JSON data: {data}') from err
|
||||
if openllm.utils.LazyType(DictStrAny).isinstance(data):
|
||||
return openllm.GenerationOutput(**data)
|
||||
elif openllm.utils.LazyType(ListAny).isinstance(data):
|
||||
return [openllm.GenerationOutput(**d) for d in data]
|
||||
else:
|
||||
raise NotImplementedError(f"Data {data} has unsupported type.")
|
||||
raise NotImplementedError(f'Data {data} has unsupported type.')
|
||||
|
||||
serialized_data = convert_data(serialized_data)
|
||||
snapshot_data = convert_data(snapshot_data)
|
||||
@@ -56,7 +56,7 @@ def response_snapshot(snapshot: SnapshotAssertion):
|
||||
@attr.define(init=False)
|
||||
class _Handle(ABC):
|
||||
port: int
|
||||
deployment_mode: t.Literal["container", "local"]
|
||||
deployment_mode: t.Literal['container', 'local']
|
||||
|
||||
client: BaseAsyncClient[t.Any] = attr.field(init=False)
|
||||
|
||||
@@ -66,7 +66,7 @@ class _Handle(ABC):
|
||||
...
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.client = openllm.client.AsyncHTTPClient(f"http://localhost:{self.port}")
|
||||
self.client = openllm.client.AsyncHTTPClient(f'http://localhost:{self.port}')
|
||||
|
||||
@abstractmethod
|
||||
def status(self) -> bool:
|
||||
@@ -76,19 +76,19 @@ class _Handle(ABC):
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if not self.status():
|
||||
raise RuntimeError(f"Failed to initialise {self.__class__.__name__}")
|
||||
raise RuntimeError(f'Failed to initialise {self.__class__.__name__}')
|
||||
await self.client.health()
|
||||
try:
|
||||
await self.client.query("sanity")
|
||||
await self.client.query('sanity')
|
||||
return
|
||||
except Exception:
|
||||
time.sleep(1)
|
||||
raise RuntimeError(f"Handle failed to initialise within {timeout} seconds.")
|
||||
raise RuntimeError(f'Handle failed to initialise within {timeout} seconds.')
|
||||
@attr.define(init=False)
|
||||
class LocalHandle(_Handle):
|
||||
process: subprocess.Popen[bytes]
|
||||
|
||||
def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal["container", "local"],):
|
||||
def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal['container', 'local'],):
|
||||
self.__attrs_init__(port, deployment_mode, process)
|
||||
|
||||
def status(self) -> bool:
|
||||
@@ -102,23 +102,23 @@ class DockerHandle(_Handle):
|
||||
container_name: str
|
||||
docker_client: docker.DockerClient
|
||||
|
||||
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal["container", "local"],):
|
||||
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal['container', 'local'],):
|
||||
self.__attrs_init__(port, deployment_mode, container_name, docker_client)
|
||||
|
||||
def status(self) -> bool:
|
||||
container = self.docker_client.containers.get(self.container_name)
|
||||
return container.status in ["running", "created"]
|
||||
return container.status in ['running', 'created']
|
||||
@contextlib.contextmanager
|
||||
def _local_handle(
|
||||
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,
|
||||
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
|
||||
):
|
||||
with openllm.utils.reserve_free_port() as port:
|
||||
pass
|
||||
|
||||
if not _serve_grpc:
|
||||
proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
|
||||
proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
|
||||
else:
|
||||
proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=["--port", str(port)], __test__=True)
|
||||
proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
|
||||
|
||||
yield LocalHandle(proc, port, deployment_mode)
|
||||
proc.terminate()
|
||||
@@ -132,13 +132,13 @@ def _local_handle(
|
||||
proc.stderr.close()
|
||||
@contextlib.contextmanager
|
||||
def _container_handle(
|
||||
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal["container", "local"], quantize: t.Literal["int8", "int4", "gptq"] | None = None, *, _serve_grpc: bool = False,
|
||||
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
|
||||
):
|
||||
envvar = openllm.utils.EnvVarMixin(model)
|
||||
|
||||
with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
|
||||
pass
|
||||
container_name = f"openllm-{model}-{normalise_model_name(model_id)}".replace("-", "_")
|
||||
container_name = f'openllm-{model}-{normalise_model_name(model_id)}'.replace('-', '_')
|
||||
client = docker.from_env()
|
||||
try:
|
||||
container = client.containers.get(container_name)
|
||||
@@ -148,7 +148,7 @@ def _container_handle(
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
|
||||
args = ["serve" if not _serve_grpc else "serve-grpc"]
|
||||
args = ['serve' if not _serve_grpc else 'serve-grpc']
|
||||
|
||||
env: DictStrAny = {}
|
||||
|
||||
@@ -156,11 +156,11 @@ def _container_handle(
|
||||
env[envvar.quantize] = quantize
|
||||
|
||||
gpus = openllm.utils.device_count() or -1
|
||||
devs = [docker.types.DeviceRequest(count=gpus, capabilities=[["gpu"]])] if gpus > 0 else None
|
||||
devs = [docker.types.DeviceRequest(count=gpus, capabilities=[['gpu']])] if gpus > 0 else None
|
||||
|
||||
container = client.containers.run(
|
||||
image_tag, command=args, name=container_name, environment=env, auto_remove=False, detach=True, device_requests=devs, ports={
|
||||
"3000/tcp": port, "3001/tcp": prom_port
|
||||
'3000/tcp': port, '3001/tcp': prom_port
|
||||
},
|
||||
)
|
||||
|
||||
@@ -172,28 +172,28 @@ def _container_handle(
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
|
||||
container_output = container.logs().decode("utf-8")
|
||||
container_output = container.logs().decode('utf-8')
|
||||
print(container_output, file=sys.stderr)
|
||||
|
||||
container.remove()
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
|
||||
stack = contextlib.ExitStack()
|
||||
yield stack
|
||||
stack.close()
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope='module')
|
||||
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
@pytest.fixture(params=["container", "local"], scope="session")
|
||||
@pytest.fixture(params=['container', 'local'], scope='session')
|
||||
def deployment_mode(request: pytest.FixtureRequest) -> str:
|
||||
return request.param
|
||||
@pytest.fixture(scope="module")
|
||||
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal["container", "local"]):
|
||||
if deployment_mode == "container":
|
||||
@pytest.fixture(scope='module')
|
||||
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal['container', 'local']):
|
||||
if deployment_mode == 'container':
|
||||
return functools.partial(_container_handle, deployment_mode=deployment_mode)
|
||||
elif deployment_mode == "local":
|
||||
elif deployment_mode == 'local':
|
||||
return functools.partial(_local_handle, deployment_mode=deployment_mode)
|
||||
else:
|
||||
raise ValueError(f"Unknown deployment mode: {deployment_mode}")
|
||||
raise ValueError(f'Unknown deployment mode: {deployment_mode}')
|
||||
|
||||
@@ -9,21 +9,21 @@ if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
from .conftest import HandleProtocol, ResponseComparator, _Handle
|
||||
model = "flan_t5"
|
||||
model_id = "google/flan-t5-small"
|
||||
@pytest.fixture(scope="module")
|
||||
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
|
||||
model = 'flan_t5'
|
||||
model_id = 'google/flan-t5-small'
|
||||
@pytest.fixture(scope='module')
|
||||
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
|
||||
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
|
||||
yield handle
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope='module')
|
||||
async def flan_t5(flan_t5_handle: _Handle):
|
||||
await flan_t5_handle.health(240)
|
||||
return flan_t5_handle.client
|
||||
@pytest.mark.asyncio()
|
||||
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
|
||||
client = await flan_t5
|
||||
response = await client.query("What is the meaning of life?", max_new_tokens=10, top_p=0.9, return_response="attrs")
|
||||
response = await client.query('What is the meaning of life?', max_new_tokens=10, top_p=0.9, return_response='attrs')
|
||||
|
||||
assert response.configuration["generation_config"]["max_new_tokens"] == 10
|
||||
assert response.configuration['generation_config']['max_new_tokens'] == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
@@ -9,21 +9,21 @@ if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
from .conftest import HandleProtocol, ResponseComparator, _Handle
|
||||
model = "opt"
|
||||
model_id = "facebook/opt-125m"
|
||||
@pytest.fixture(scope="module")
|
||||
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal["container", "local"], clean_context: contextlib.ExitStack,):
|
||||
model = 'opt'
|
||||
model_id = 'facebook/opt-125m'
|
||||
@pytest.fixture(scope='module')
|
||||
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
|
||||
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
|
||||
yield handle
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope='module')
|
||||
async def opt_125m(opt_125m_handle: _Handle):
|
||||
await opt_125m_handle.health(240)
|
||||
return opt_125m_handle.client
|
||||
@pytest.mark.asyncio()
|
||||
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
|
||||
client = await opt_125m
|
||||
response = await client.query("What is Deep learning?", max_new_tokens=20, return_response="attrs")
|
||||
response = await client.query('What is Deep learning?', max_new_tokens=20, return_response='attrs')
|
||||
|
||||
assert response.configuration["generation_config"]["max_new_tokens"] == 20
|
||||
assert response.configuration['generation_config']['max_new_tokens'] == 20
|
||||
assert response == response_snapshot
|
||||
|
||||
@@ -2,17 +2,17 @@ from __future__ import annotations
|
||||
import os, typing as t, pytest
|
||||
|
||||
if t.TYPE_CHECKING: import openllm
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
|
||||
def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
|
||||
assert llm(prompt)
|
||||
|
||||
assert llm(prompt, temperature=0.8, top_p=0.23)
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
|
||||
def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
|
||||
assert llm(prompt)
|
||||
|
||||
assert llm(prompt, temperature=0.9, top_k=8)
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="Model is too large for CI")
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
|
||||
def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
|
||||
assert llm(prompt)
|
||||
|
||||
|
||||
@@ -3,40 +3,40 @@ import functools, os, typing as t, pytest, openllm
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
if t.TYPE_CHECKING: from pathlib import Path
|
||||
|
||||
HF_INTERNAL_T5_TESTING = "hf-internal-testing/tiny-random-t5"
|
||||
HF_INTERNAL_T5_TESTING = 'hf-internal-testing/tiny-random-t5'
|
||||
|
||||
actions_xfail = functools.partial(
|
||||
pytest.mark.xfail, condition=os.getenv("GITHUB_ACTIONS") is not None, reason="Marking GitHub Actions to xfail due to flakiness and building environment not isolated.",
|
||||
pytest.mark.xfail, condition=os.getenv('GITHUB_ACTIONS') is not None, reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.',
|
||||
)
|
||||
@actions_xfail
|
||||
def test_general_build_with_internal_testing():
|
||||
bento_store = BentoMLContainer.bento_store.get()
|
||||
|
||||
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
|
||||
bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
|
||||
llm = openllm.AutoLLM.for_model('flan-t5', model_id=HF_INTERNAL_T5_TESTING)
|
||||
bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING)
|
||||
|
||||
assert llm.llm_type == bento.info.labels["_type"]
|
||||
assert llm.config["env"]["framework_value"] == bento.info.labels["_framework"]
|
||||
assert llm.llm_type == bento.info.labels['_type']
|
||||
assert llm.config['env']['framework_value'] == bento.info.labels['_framework']
|
||||
|
||||
bento = openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING)
|
||||
bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING)
|
||||
assert len(bento_store.list(bento.tag)) == 1
|
||||
@actions_xfail
|
||||
def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
|
||||
local_path = tmp_path_factory.mktemp("local_t5")
|
||||
llm = openllm.AutoLLM.for_model("flan-t5", model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
|
||||
local_path = tmp_path_factory.mktemp('local_t5')
|
||||
llm = openllm.AutoLLM.for_model('flan-t5', model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
|
||||
|
||||
if llm.bettertransformer:
|
||||
llm.__llm_model__ = llm.model.reverse_bettertransformer()
|
||||
|
||||
llm.save_pretrained(local_path)
|
||||
|
||||
assert openllm.build("flan-t5", model_id=local_path.resolve().__fspath__(), model_version="local")
|
||||
assert openllm.build('flan-t5', model_id=local_path.resolve().__fspath__(), model_version='local')
|
||||
@pytest.fixture()
|
||||
def dockerfile_template(tmp_path_factory: pytest.TempPathFactory):
|
||||
file = tmp_path_factory.mktemp("dockerfiles") / "Dockerfile.template"
|
||||
file = tmp_path_factory.mktemp('dockerfiles') / 'Dockerfile.template'
|
||||
file.write_text("{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}")
|
||||
return file
|
||||
@pytest.mark.usefixtures("dockerfile_template")
|
||||
@pytest.mark.usefixtures('dockerfile_template')
|
||||
@actions_xfail
|
||||
def test_build_with_custom_dockerfile(dockerfile_template: Path):
|
||||
assert openllm.build("flan-t5", model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template))
|
||||
assert openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING, dockerfile_template=str(dockerfile_template))
|
||||
|
||||
@@ -5,79 +5,79 @@ from openllm_core._strategies import CascadingResourceStrategy, NvidiaGpuResourc
|
||||
if t.TYPE_CHECKING: from _pytest.monkeypatch import MonkeyPatch
|
||||
def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "0,1")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '0,1')
|
||||
resource = NvidiaGpuResource.from_system()
|
||||
assert len(resource) == 2
|
||||
assert resource == ["0", "1"]
|
||||
mcls.delenv("CUDA_VISIBLE_DEVICES")
|
||||
assert resource == ['0', '1']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "0,2,-1,1")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '0,2,-1,1')
|
||||
resource = NvidiaGpuResource.from_system()
|
||||
assert len(resource) == 2
|
||||
assert resource == ["0", "2"]
|
||||
mcls.delenv("CUDA_VISIBLE_DEVICES")
|
||||
assert resource == ['0', '2']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "-1")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '-1')
|
||||
resource = NvidiaGpuResource.from_system()
|
||||
assert len(resource) == 0
|
||||
assert resource == []
|
||||
mcls.delenv("CUDA_VISIBLE_DEVICES")
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43-ac33420d4628")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43-ac33420d4628')
|
||||
resource = NvidiaGpuResource.from_system()
|
||||
assert len(resource) == 1
|
||||
assert resource == ["GPU-5ebe9f43-ac33420d4628"]
|
||||
mcls.delenv("CUDA_VISIBLE_DEVICES")
|
||||
assert resource == ['GPU-5ebe9f43-ac33420d4628']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,GPU-ac33420d4628")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43,GPU-ac33420d4628')
|
||||
resource = NvidiaGpuResource.from_system()
|
||||
assert len(resource) == 2
|
||||
assert resource == ["GPU-5ebe9f43", "GPU-ac33420d4628"]
|
||||
mcls.delenv("CUDA_VISIBLE_DEVICES")
|
||||
assert resource == ['GPU-5ebe9f43', 'GPU-ac33420d4628']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "GPU-5ebe9f43,-1,GPU-ac33420d4628")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43,-1,GPU-ac33420d4628')
|
||||
resource = NvidiaGpuResource.from_system()
|
||||
assert len(resource) == 1
|
||||
assert resource == ["GPU-5ebe9f43"]
|
||||
mcls.delenv("CUDA_VISIBLE_DEVICES")
|
||||
assert resource == ['GPU-5ebe9f43']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "MIG-GPU-5ebe9f43-ac33420d4628")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', 'MIG-GPU-5ebe9f43-ac33420d4628')
|
||||
resource = NvidiaGpuResource.from_system()
|
||||
assert len(resource) == 1
|
||||
assert resource == ["MIG-GPU-5ebe9f43-ac33420d4628"]
|
||||
mcls.delenv("CUDA_VISIBLE_DEVICES")
|
||||
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") is not None, reason="skip GPUs test on CI")
|
||||
assert resource == ['MIG-GPU-5ebe9f43-ac33420d4628']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='skip GPUs test on CI')
|
||||
def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
# to make this tests works with system that has GPU
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "")
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '')
|
||||
assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests
|
||||
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match("Input list should be all string type.")
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match("Input list should be all string type.")
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ["GPU-5ebe9f43", "GPU-ac33420d4628"]).match("Failed to parse available GPUs UUID")
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match('Input list should be all string type.')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
|
||||
def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
# to make this tests works with system that has GPU
|
||||
mcls.setenv("CUDA_VISIBLE_DEVICES", "")
|
||||
assert NvidiaGpuResource.from_spec(1) == ["0"]
|
||||
assert NvidiaGpuResource.from_spec("5") == ["0", "1", "2", "3", "4"]
|
||||
assert NvidiaGpuResource.from_spec(1) == ["0"]
|
||||
assert NvidiaGpuResource.from_spec(2) == ["0", "1"]
|
||||
assert NvidiaGpuResource.from_spec("3") == ["0", "1", "2"]
|
||||
assert NvidiaGpuResource.from_spec([1, 3]) == ["1", "3"]
|
||||
assert NvidiaGpuResource.from_spec(["1", "3"]) == ["1", "3"]
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '')
|
||||
assert NvidiaGpuResource.from_spec(1) == ['0']
|
||||
assert NvidiaGpuResource.from_spec('5') == ['0', '1', '2', '3', '4']
|
||||
assert NvidiaGpuResource.from_spec(1) == ['0']
|
||||
assert NvidiaGpuResource.from_spec(2) == ['0', '1']
|
||||
assert NvidiaGpuResource.from_spec('3') == ['0', '1', '2']
|
||||
assert NvidiaGpuResource.from_spec([1, 3]) == ['1', '3']
|
||||
assert NvidiaGpuResource.from_spec(['1', '3']) == ['1', '3']
|
||||
assert NvidiaGpuResource.from_spec(-1) == []
|
||||
assert NvidiaGpuResource.from_spec("-1") == []
|
||||
assert NvidiaGpuResource.from_spec("") == []
|
||||
assert NvidiaGpuResource.from_spec("-2") == []
|
||||
assert NvidiaGpuResource.from_spec("GPU-288347ab") == ["GPU-288347ab"]
|
||||
assert NvidiaGpuResource.from_spec("GPU-288347ab,-1,GPU-ac33420d4628") == ["GPU-288347ab"]
|
||||
assert NvidiaGpuResource.from_spec("GPU-288347ab,GPU-ac33420d4628") == ["GPU-288347ab", "GPU-ac33420d4628"]
|
||||
assert NvidiaGpuResource.from_spec("MIG-GPU-288347ab") == ["MIG-GPU-288347ab"]
|
||||
assert NvidiaGpuResource.from_spec('-1') == []
|
||||
assert NvidiaGpuResource.from_spec('') == []
|
||||
assert NvidiaGpuResource.from_spec('-2') == []
|
||||
assert NvidiaGpuResource.from_spec('GPU-288347ab') == ['GPU-288347ab']
|
||||
assert NvidiaGpuResource.from_spec('GPU-288347ab,-1,GPU-ac33420d4628') == ['GPU-288347ab']
|
||||
assert NvidiaGpuResource.from_spec('GPU-288347ab,GPU-ac33420d4628') == ['GPU-288347ab', 'GPU-ac33420d4628']
|
||||
assert NvidiaGpuResource.from_spec('MIG-GPU-288347ab') == ['MIG-GPU-288347ab']
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
NvidiaGpuResource.from_spec((1, 2, 3))
|
||||
@@ -86,12 +86,12 @@ def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
|
||||
with pytest.raises(ValueError):
|
||||
assert NvidiaGpuResource.from_spec(-2)
|
||||
class GPURunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "amd.com/gpu")
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu')
|
||||
def unvalidated_get_resource(x: dict[str, t.Any], y: str, validate: bool = False):
|
||||
return get_resource(x, y, validate=validate)
|
||||
@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
|
||||
@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
|
||||
def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
|
||||
monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)
|
||||
assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: 2}, 1) == 1
|
||||
assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7]}, 1) == 1
|
||||
|
||||
@@ -99,54 +99,54 @@ def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 9]}, 0.5) == 1
|
||||
assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 1
|
||||
assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 1
|
||||
@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
|
||||
@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
|
||||
def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
|
||||
monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)
|
||||
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '0'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "1"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '1'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 1, 1)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "7"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '7'
|
||||
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 0)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '0'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 1)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "0"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '0'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 2, 2)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "1"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '1'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 1)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "2"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '2'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 2, 2)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "7"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '7'
|
||||
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7]}, 0.5, 0)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7'
|
||||
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 0)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5, 1)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "8,9"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '8,9'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.25, 0)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,7,8,9"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '2,7,8,9'
|
||||
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 0)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "2,6"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '2,6'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 1)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "7,8"
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '7,8'
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "9"
|
||||
@pytest.mark.parametrize("gpu_type", ["nvidia.com/gpu", "amd.com/gpu"])
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '9'
|
||||
@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
|
||||
def test_cascade_strategy_disabled_via_env(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
monkeypatch.setattr(strategy, "get_resource", unvalidated_get_resource)
|
||||
monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)
|
||||
|
||||
monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
|
||||
monkeypatch.setenv('CUDA_VISIBLE_DEVICES', '')
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 0)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == ""
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES")
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == ''
|
||||
monkeypatch.delenv('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "-1")
|
||||
monkeypatch.setenv('CUDA_VISIBLE_DEVICES', '-1')
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: 2}, 1, 1)
|
||||
assert envs.get("CUDA_VISIBLE_DEVICES") == "-1"
|
||||
monkeypatch.delenv("CUDA_VISIBLE_DEVICES")
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '-1'
|
||||
monkeypatch.delenv('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
Reference in New Issue
Block a user