style: google

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron
2023-08-30 13:52:00 -04:00
parent e2ba6a92a6
commit b545ad2ad1
98 changed files with 3514 additions and 2094 deletions

View File

@@ -16,19 +16,32 @@ env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_na
def model_settings(draw: st.DrawFn):
'''Strategy for generating ModelSettings objects.'''
kwargs: dict[str, t.Any] = {
'default_id': st.text(min_size=1),
'model_ids': st.lists(st.text(), min_size=1),
'architecture': st.text(min_size=1),
'url': st.text(),
'requires_gpu': st.booleans(),
'trust_remote_code': st.booleans(),
'requirements': st.none() | st.lists(st.text(), min_size=1),
'default_implementation': st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])),
'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
'runtime': st.sampled_from(['transformers', 'ggml']),
'name_type': st.sampled_from(['dasherize', 'lowercase']),
'timeout': st.integers(min_value=3600),
'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
'default_id':
st.text(min_size=1),
'model_ids':
st.lists(st.text(), min_size=1),
'architecture':
st.text(min_size=1),
'url':
st.text(),
'requires_gpu':
st.booleans(),
'trust_remote_code':
st.booleans(),
'requirements':
st.none() | st.lists(st.text(), min_size=1),
'default_implementation':
st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])),
'model_type':
st.sampled_from(['causal_lm', 'seq2seq_lm']),
'runtime':
st.sampled_from(['transformers', 'ggml']),
'name_type':
st.sampled_from(['dasherize', 'lowercase']),
'timeout':
st.integers(min_value=3600),
'workers_per_resource':
st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
}
return draw(st.builds(ModelSettings, **kwargs))

View File

@@ -24,19 +24,29 @@ from ._strategies._configuration import make_llm_config
from ._strategies._configuration import model_settings
# XXX: @aarnphm fixes TypedDict behaviour in 3.11
@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason='TypedDict in 3.11 behaves differently, so we need to fix this')
@pytest.mark.skipif(sys.version_info[:2] == (3, 11),
reason='TypedDict in 3.11 behaves differently, so we need to fix this')
def test_missing_default():
with pytest.raises(ValueError, match='Missing required fields *'):
make_llm_config('MissingDefaultId', {'name_type': 'lowercase', 'requirements': ['bentoml']})
with pytest.raises(ValueError, match='Missing required fields *'):
make_llm_config('MissingModelId', {'default_id': 'huggingface/t5-tiny-testing', 'requirements': ['bentoml']})
with pytest.raises(ValueError, match='Missing required fields *'):
make_llm_config('MissingArchitecture', {'default_id': 'huggingface/t5-tiny-testing', 'model_ids': ['huggingface/t5-tiny-testing'], 'requirements': ['bentoml'],},)
make_llm_config(
'MissingArchitecture', {
'default_id': 'huggingface/t5-tiny-testing',
'model_ids': ['huggingface/t5-tiny-testing'],
'requirements': ['bentoml'],
},
)
def test_forbidden_access():
cl_ = make_llm_config(
'ForbiddenAccess', {
'default_id': 'huggingface/t5-tiny-testing', 'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'], 'architecture': 'PreTrainedModel', 'requirements': ['bentoml'],
'default_id': 'huggingface/t5-tiny-testing',
'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'],
'architecture': 'PreTrainedModel',
'requirements': ['bentoml'],
},
)
@@ -69,9 +79,16 @@ def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
cl_ = make_llm_config('AttrsProtocolLLM', gen_settings)
assert attr.has(cl_)
@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),)
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),),)
@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),
st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),
)
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int,
input_temperature: float):
cl_ = make_llm_config('ComplexLLM',
gen_settings,
fields=(('field1', 'float', field1),),
generation_fields=(('temperature', temperature),),
)
sent = cl_()
assert sent.model_dump()['field1'] == field1
assert sent.model_dump()['generation_config']['temperature'] == temperature
@@ -94,7 +111,10 @@ def patch_env(**attrs: t.Any):
yield
def test_struct_envvar():
with patch_env(**{field_env_key('env_llm', 'field1'): '4', field_env_key('env_llm', 'temperature', suffix='generation'): '0.2',}):
with patch_env(**{
field_env_key('env_llm', 'field1'): '4',
field_env_key('env_llm', 'temperature', suffix='generation'): '0.2',
}):
class EnvLLM(openllm.LLMConfig):
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
@@ -112,6 +132,7 @@ def test_struct_envvar():
assert overwrite_default['temperature'] == 0.2
def test_struct_provided_fields():
class EnvLLM(openllm.LLMConfig):
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
field1: int = 2
@@ -127,11 +148,13 @@ def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPat
with monkeypatch.context() as mk:
mk.setenv(field_env_key('overwrite_with_env_available', 'field1'), str(4.0))
mk.setenv(field_env_key('overwrite_with_env_available', 'temperature', suffix='generation'), str(0.2))
sent = make_llm_config(
'OverwriteWithEnvAvailable', {
'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'
}, fields=(('field1', 'float', 3.0),),
).model_construct_env(field1=20.0, temperature=0.4)
sent = make_llm_config('OverwriteWithEnvAvailable', {
'default_id': 'asdfasdf',
'model_ids': ['asdf', 'asdfasdfads'],
'architecture': 'PreTrainedModel'
},
fields=(('field1', 'float', 3.0),),
).model_construct_env(field1=20.0, temperature=0.4)
assert sent.generation_config.temperature == 0.4
assert sent.field1 == 20.0

View File

@@ -10,23 +10,37 @@ import openllm
if t.TYPE_CHECKING:
from openllm_core._typing_compat import LiteralRuntime
_FRAMEWORK_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B',}
_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?',}
_FRAMEWORK_MAPPING = {
'flan_t5': 'google/flan-t5-small',
'opt': 'facebook/opt-125m',
'baichuan': 'baichuan-inc/Baichuan-7B',
}
_PROMPT_MAPPING = {
'qa':
'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?',
}
def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
def parametrise_local_llm(
model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
if model not in _FRAMEWORK_MAPPING: pytest.skip(f"'{model}' is not yet supported in framework testing.")
runtime_impl: tuple[LiteralRuntime, ...] = tuple()
if model in openllm.MODEL_MAPPING_NAMES: runtime_impl += ('pt',)
if model in openllm.MODEL_FLAX_MAPPING_NAMES: runtime_impl += ('flax',)
if model in openllm.MODEL_TF_MAPPING_NAMES: runtime_impl += ('tf',)
for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()):
llm = openllm.Runner(model, model_id=_FRAMEWORK_MAPPING[model], ensure_available=True, implementation=framework, init_local=True,)
llm = openllm.Runner(model,
model_id=_FRAMEWORK_MAPPING[model],
ensure_available=True,
implementation=framework,
init_local=True,
)
yield prompt, llm
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
if os.getenv('GITHUB_ACTIONS') is None:
if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames:
metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
metafunc.parametrize('prompt,llm',
[(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
# If no tests are collected, pytest exists with code 5, which makes the CI fail.

View File

@@ -40,7 +40,13 @@ if t.TYPE_CHECKING:
from openllm.client import BaseAsyncClient
class ResponseComparator(JSONSnapshotExtension):
def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
def serialize(self,
data: SerializableData,
*,
exclude: PropertyFilter | None = None,
matcher: PropertyMatcher | None = None,
) -> SerializedData:
if openllm.utils.LazyType(ListAny).isinstance(data):
data = [d.unmarshaled for d in data]
else:
@@ -49,6 +55,7 @@ class ResponseComparator(JSONSnapshotExtension):
return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
try:
data = orjson.loads(data)
@@ -73,9 +80,11 @@ class ResponseComparator(JSONSnapshotExtension):
return s == t
def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and
eq_config(s.marshaled_config, t.marshaled_config))
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
return len(serialized_data) == len(snapshot_data) and all(
[eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
@pytest.fixture()
def response_snapshot(snapshot: SnapshotAssertion):
@@ -124,8 +133,14 @@ class LocalHandle(_Handle):
return self.process.poll() is None
class HandleProtocol(t.Protocol):
@contextlib.contextmanager
def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
def __call__(*,
model: str,
model_id: str,
image_tag: str,
quantize: t.AnyStr | None = None,
) -> t.Generator[_Handle, None, None]:
...
@attr.define(init=False)
@@ -133,7 +148,9 @@ class DockerHandle(_Handle):
container_name: str
docker_client: docker.DockerClient
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal['container', 'local'],):
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int,
deployment_mode: t.Literal['container', 'local'],
):
self.__attrs_init__(port, deployment_mode, container_name, docker_client)
def status(self) -> bool:
@@ -141,16 +158,29 @@ class DockerHandle(_Handle):
return container.status in ['running', 'created']
@contextlib.contextmanager
def _local_handle(
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
):
def _local_handle(model: str,
model_id: str,
image_tag: str,
deployment_mode: t.Literal['container', 'local'],
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
*,
_serve_grpc: bool = False,
):
with openllm.utils.reserve_free_port() as port:
pass
if not _serve_grpc:
proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
proc = openllm.start(model,
model_id=model_id,
quantize=quantize,
additional_args=['--port', str(port)],
__test__=True)
else:
proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
proc = openllm.start_grpc(model,
model_id=model_id,
quantize=quantize,
additional_args=['--port', str(port)],
__test__=True)
yield LocalHandle(proc, port, deployment_mode)
proc.terminate()
@@ -164,9 +194,14 @@ def _local_handle(
proc.stderr.close()
@contextlib.contextmanager
def _container_handle(
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
):
def _container_handle(model: str,
model_id: str,
image_tag: str,
deployment_mode: t.Literal['container', 'local'],
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
*,
_serve_grpc: bool = False,
):
envvar = openllm.utils.EnvVarMixin(model)
with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
@@ -191,11 +226,18 @@ def _container_handle(
gpus = openllm.utils.device_count() or -1
devs = [docker.types.DeviceRequest(count=gpus, capabilities=[['gpu']])] if gpus > 0 else None
container = client.containers.run(
image_tag, command=args, name=container_name, environment=env, auto_remove=False, detach=True, device_requests=devs, ports={
'3000/tcp': port, '3001/tcp': prom_port
},
)
container = client.containers.run(image_tag,
command=args,
name=container_name,
environment=env,
auto_remove=False,
detach=True,
device_requests=devs,
ports={
'3000/tcp': port,
'3001/tcp': prom_port
},
)
yield DockerHandle(client, container.name, port, deployment_mode)

View File

@@ -16,8 +16,11 @@ model = 'flan_t5'
model_id = 'google/flan-t5-small'
@pytest.fixture(scope='module')
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'],
clean_context: contextlib.ExitStack,
):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode,
clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle

View File

@@ -16,8 +16,11 @@ model = 'opt'
model_id = 'facebook/opt-125m'
@pytest.fixture(scope='module')
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'],
clean_context: contextlib.ExitStack,
):
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode,
clean_context=clean_context) as image_tag:
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
yield handle

View File

@@ -15,7 +15,9 @@ if t.TYPE_CHECKING:
HF_INTERNAL_T5_TESTING = 'hf-internal-testing/tiny-random-t5'
actions_xfail = functools.partial(
pytest.mark.xfail, condition=os.getenv('GITHUB_ACTIONS') is not None, reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.',
pytest.mark.xfail,
condition=os.getenv('GITHUB_ACTIONS') is not None,
reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.',
)
@actions_xfail
@@ -46,7 +48,9 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
@pytest.fixture()
def dockerfile_template(tmp_path_factory: pytest.TempPathFactory):
file = tmp_path_factory.mktemp('dockerfiles') / 'Dockerfile.template'
file.write_text("{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}")
file.write_text(
"{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}"
)
return file
@pytest.mark.usefixtures('dockerfile_template')

View File

@@ -71,9 +71,11 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
mcls.setenv('CUDA_VISIBLE_DEVICES', '')
assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match('Input list should be all string type.')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],
).match('Input list should be all string type.')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.')
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
assert pytest.raises(ValueError, NvidiaGpuResource.validate,
['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as mcls: