mirror of
https://github.com/bentoml/OpenLLM.git
synced 2025-12-23 23:57:46 -05:00
153 lines
5.6 KiB
Python
153 lines
5.6 KiB
Python
from __future__ import annotations
|
|
import contextlib
|
|
import os
|
|
import typing as t
|
|
from unittest import mock
|
|
|
|
import attr
|
|
import pytest
|
|
from hypothesis import assume, given, strategies as st
|
|
|
|
import openllm
|
|
from openllm_core._configuration import GenerationConfig, ModelSettings, field_env_key
|
|
|
|
from ._strategies._configuration import make_llm_config, model_settings
|
|
|
|
|
|
def test_forbidden_access():
    """Reserved names must not be readable through a config instance.

    Looks up ``__config__``, ``GenerationConfig`` and ``SamplingParams`` through the
    class's own ``__getattribute__`` and expects ``ForbiddenAttributeError`` for each.
    The generation class must still be reachable via ``__openllm_generation_class__``.
    """
    cl_ = make_llm_config(
        'ForbiddenAccess',
        {
            'default_id': 'huggingface/t5-tiny-testing',
            'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'],
            'architecture': 'PreTrainedModel',
            'requirements': ['bentoml'],
        },
    )
    instance = cl_()

    # The context-manager form of ``pytest.raises`` fails the test on its own when
    # nothing is raised. The previous ``assert pytest.raises(...)`` only re-checked
    # the ``ExceptionInfo`` object that ``pytest.raises`` returns.
    for forbidden in ('__config__', 'GenerationConfig', 'SamplingParams'):
        with pytest.raises(openllm.exceptions.ForbiddenAttributeError):
            # Keep the original lookup path (the class's __getattribute__) rather than
            # getattr(), which layers __getattr__ fallback semantics on top.
            cl_.__getattribute__(instance, forbidden)

    assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)
|
|
|
|
|
|
@given(model_settings())
def test_class_normal_gen(gen_settings: ModelSettings):
    """Every generated ``ModelSettings`` key is mirrored as a ``__openllm_<key>__`` attribute."""
    # Discard examples whose default id is falsy or that contain a falsy model id.
    assume(gen_settings['default_id'] and all(gen_settings['model_ids']))
    cl_: type[openllm.LLMConfig] = make_llm_config('NotFullLLM', gen_settings)
    assert issubclass(cl_, openllm.LLMConfig)
    for key in gen_settings:
        # NOTE(review): object.__getattribute__ is presumably used to sidestep the
        # config machinery's own attribute hooks on the class — confirm.
        actual = object.__getattribute__(cl_, f'__openllm_{key}__')
        assert actual == gen_settings[key], f'setting {key!r} was not propagated to the class'
|
|
|
|
|
|
@given(model_settings(), st.integers())
def test_simple_struct_dump(gen_settings: ModelSettings, field1: int):
    """A declared field's default value round-trips through ``model_dump``."""
    config_cls = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
    dumped = config_cls().model_dump()
    assert dumped['field1'] == field1
|
|
|
|
|
|
@given(model_settings(), st.integers())
def test_config_derivation(gen_settings: ModelSettings, field1: int):
    """``model_derivate`` produces a derived config exposing the overridden ``default_id``."""
    base_cls = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
    derived_cls = base_cls.model_derivate('DerivedLLM', default_id='asdfasdf')
    assert derived_cls.__openllm_default_id__ == 'asdfasdf'
|
|
|
|
|
|
@given(model_settings())
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
    """Config classes built by ``make_llm_config`` are recognised by ``attr.has``."""
    generated = make_llm_config('AttrsProtocolLLM', gen_settings)
    is_attrs_class = attr.has(generated)
    assert is_attrs_class
|
|
|
|
|
|
@given(
    model_settings(),
    st.integers(max_value=283473),
    st.floats(min_value=0.0, max_value=1.0),
    st.integers(max_value=283473),
    st.floats(min_value=0.0, max_value=1.0),
)
def test_complex_struct_dump(
    gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float
):
    """Plain and generation fields appear in both nested and flattened dumps.

    The check covers three cases:
    - defaults taken from the class definition;
    - values passed as flat keyword arguments;
    - a nested ``generation_config`` mapping (checked on the nested dump only).
    """

    def check_dumps(config, expected_field1, expected_temperature, *, check_flat=True):
        # Dump once per flavour instead of re-serialising for every assertion.
        nested = config.model_dump()
        assert nested['field1'] == expected_field1
        assert nested['generation_config']['temperature'] == expected_temperature
        if check_flat:
            flat = config.model_dump(flatten=True)
            assert flat['field1'] == expected_field1
            assert flat['temperature'] == expected_temperature

    cl_ = make_llm_config(
        'ComplexLLM',
        gen_settings,
        fields=(('field1', 'float', field1),),
        generation_fields=(('temperature', temperature),),
    )

    check_dumps(cl_(), field1, temperature)
    check_dumps(cl_(field1=input_field1, temperature=input_temperature), input_field1, input_temperature)
    check_dumps(
        cl_(generation_config={'temperature': input_temperature}, field1=input_field1),
        input_field1,
        input_temperature,
        check_flat=False,
    )
|
|
|
|
|
|
@contextlib.contextmanager
def patch_env(**attrs: t.Any) -> t.Iterator[None]:
    """Temporarily replace the whole process environment with *attrs*.

    ``clear=True`` empties ``os.environ`` before applying *attrs*, so only the given
    variables exist inside the ``with`` block. ``mock.patch.dict`` restores the
    previous environment on exit, including when the body raises.
    """
    env_patcher = mock.patch.dict(os.environ, values=attrs, clear=True)
    with env_patcher:
        yield
|
|
|
|
|
|
def test_struct_envvar():
    """Environment variables override both a plain field and a generation field."""
    env_overrides = {
        field_env_key('field1'): '4',
        field_env_key('temperature', suffix='generation'): '0.2',
    }
    with patch_env(**env_overrides):

        class EnvLLM(openllm.LLMConfig):
            __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
            field1: int = 2

            class GenerationConfig:
                temperature: float = 0.8

        # Both the env-aware constructor and the plain constructor must see the overrides;
        # each instance is built and checked before the next one is constructed.
        for build in (EnvLLM.model_construct_env, EnvLLM):
            built = build()
            assert built.field1 == 4
            assert built['temperature'] == 0.2
|
|
|
|
|
|
def test_struct_provided_fields():
    """Keyword arguments to ``model_construct_env`` populate plain and generation fields."""

    class EnvLLM(openllm.LLMConfig):
        __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
        field1: int = 2

        class GenerationConfig:
            temperature: float = 0.8

    constructed = EnvLLM.model_construct_env(field1=20, temperature=0.4)
    assert constructed.field1 == 20
    assert constructed.generation_config.temperature == 0.4
|
|
|
|
|
|
def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
    """Explicit keyword arguments to ``model_construct_env`` win over matching env vars."""
    with monkeypatch.context() as patched:
        patched.setenv(field_env_key('field1'), str(4.0))
        patched.setenv(field_env_key('temperature', suffix='generation'), str(0.2))
        settings = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
        config_cls = make_llm_config('OverwriteWithEnvAvailable', settings, fields=(('field1', 'float', 3.0),))
        sent = config_cls.model_construct_env(field1=20.0, temperature=0.4)
        assert sent.generation_config.temperature == 0.4
        assert sent.field1 == 20.0
|
|
|
|
|
|
@pytest.mark.parametrize('model_name', openllm.CONFIG_MAPPING.keys())
def test_configuration_dict_protocol(model_name: str):
    """Each registered config supports list-returning ``items``/``keys``/``values`` and ``dict()``."""
    config = openllm.AutoConfig.for_model(model_name)
    for accessor in (config.items, config.keys, config.values):
        assert isinstance(accessor(), list)
    assert isinstance(dict(config), dict)
|