"""Tests for ``openllm.LLMConfig`` configuration classes.

Covers forbidden attribute access, propagation of ``ModelSettings`` onto the
generated class, ``model_dump`` output (nested and flattened), class
derivation, the attrs protocol, and overriding field defaults through
environment variables.
"""

from __future__ import annotations

import contextlib
import os
import typing as t
from unittest import mock

import attr
import pytest
from hypothesis import assume, given, strategies as st

import openllm
from openllm_core._configuration import GenerationConfig, ModelSettings, field_env_key

from ._strategies._configuration import make_llm_config, model_settings


def test_forbidden_access():
    """Looking up internal attributes through an instance raises ``ForbiddenAttributeError``."""
    cl_ = make_llm_config(
        'ForbiddenAccess',
        {
            'default_id': 'huggingface/t5-tiny-testing',
            'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'],
            'architecture': 'PreTrainedModel',
            'requirements': ['bentoml'],
        },
    )

    for forbidden in ('__config__', 'GenerationConfig', 'SamplingParams'):
        # Build the instance outside ``pytest.raises`` so that only the
        # attribute lookup itself is expected to raise.
        instance = cl_()
        with pytest.raises(openllm.exceptions.ForbiddenAttributeError):
            cl_.__getattribute__(instance, forbidden)

    assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)


@given(model_settings())
def test_class_normal_gen(gen_settings: ModelSettings):
    """Every generated setting is exposed on the class as ``__openllm_<key>__``."""
    # Skip generated settings with an empty default id or any empty model id.
    assume(gen_settings['default_id'] and all(gen_settings['model_ids']))
    cl_: type[openllm.LLMConfig] = make_llm_config('NotFullLLM', gen_settings)
    assert issubclass(cl_, openllm.LLMConfig)
    for key in gen_settings:
        assert object.__getattribute__(cl_, f'__openllm_{key}__') == gen_settings[key]


@given(model_settings(), st.integers())
def test_simple_struct_dump(gen_settings: ModelSettings, field1: int):
    """A field's default value appears unchanged in ``model_dump()``."""
    cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
    assert cl_().model_dump()['field1'] == field1


@given(model_settings(), st.integers())
def test_config_derivation(gen_settings: ModelSettings, field1: int):
    """``model_derivate`` produces a class carrying the overridden ``default_id``."""
    cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
    new_cls = cl_.model_derivate('DerivedLLM', default_id='asdfasdf')
    assert new_cls.__openllm_default_id__ == 'asdfasdf'


@given(model_settings())
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
    """Generated config classes are attrs classes."""
    cl_ = make_llm_config('AttrsProtocolLLM', gen_settings)
    assert attr.has(cl_)


def _assert_dump_values(config: t.Any, field1: t.Any, temperature: t.Any) -> None:
    """Check ``field1``/``temperature`` in both the nested and the flattened dump of *config*."""
    nested = config.model_dump()
    assert nested['field1'] == field1
    assert nested['generation_config']['temperature'] == temperature

    flat = config.model_dump(flatten=True)
    assert flat['field1'] == field1
    assert flat['temperature'] == temperature


@given(
    model_settings(),
    st.integers(max_value=283473),
    st.floats(min_value=0.0, max_value=1.0),
    st.integers(max_value=283473),
    st.floats(min_value=0.0, max_value=1.0),
)
def test_complex_struct_dump(
    gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float
):
    """Defaults and constructor overrides (flat or nested) are reflected in ``model_dump``."""
    cl_ = make_llm_config(
        'ComplexLLM',
        gen_settings,
        fields=(('field1', 'float', field1),),
        generation_fields=(('temperature', temperature),),
    )

    # Class defaults.
    _assert_dump_values(cl_(), field1, temperature)

    # Generation fields passed as flat keyword arguments.
    _assert_dump_values(cl_(field1=input_field1, temperature=input_temperature), input_field1, input_temperature)

    # Generation fields passed through a nested ``generation_config`` mapping.
    pas_nested = cl_(generation_config={'temperature': input_temperature}, field1=input_field1)
    nested_dump = pas_nested.model_dump()
    assert nested_dump['field1'] == input_field1
    assert nested_dump['generation_config']['temperature'] == input_temperature


@contextlib.contextmanager
def patch_env(**attrs: str) -> t.Iterator[None]:
    """Temporarily make ``os.environ`` contain exactly ``attrs``.

    ``clear=True`` empties the environment before applying ``attrs``, so only
    the given variables are visible inside the block; ``mock.patch.dict``
    restores the original environment on exit. Values must be strings, as
    required by ``os.environ``.
    """
    with mock.patch.dict(os.environ, attrs, clear=True):
        yield


def test_struct_envvar():
    """Environment variables override class defaults for plain and generation fields."""
    with patch_env(**{field_env_key('field1'): '4', field_env_key('temperature', suffix='generation'): '0.2'}):

        class EnvLLM(openllm.LLMConfig):
            __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
            field1: int = 2

            class GenerationConfig:
                temperature: float = 0.8

        sent = EnvLLM.model_construct_env()
        assert sent.field1 == 4
        assert sent['temperature'] == 0.2

        overwrite_default = EnvLLM()
        assert overwrite_default.field1 == 4
        assert overwrite_default['temperature'] == 0.2


def test_struct_provided_fields():
    """Keyword arguments to ``model_construct_env`` set both plain and generation fields."""

    class EnvLLM(openllm.LLMConfig):
        __config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
        field1: int = 2

        class GenerationConfig:
            temperature: float = 0.8

    sent = EnvLLM.model_construct_env(field1=20, temperature=0.4)
    assert sent.field1 == 20
    assert sent.generation_config.temperature == 0.4


def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
    """Explicit keyword arguments win over environment variables."""
    with monkeypatch.context() as mk:
        mk.setenv(field_env_key('field1'), str(4.0))
        mk.setenv(field_env_key('temperature', suffix='generation'), str(0.2))
        sent = make_llm_config(
            'OverwriteWithEnvAvailable',
            {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'},
            fields=(('field1', 'float', 3.0),),
        ).model_construct_env(field1=20.0, temperature=0.4)
        assert sent.generation_config.temperature == 0.4
        assert sent.field1 == 20.0


@pytest.mark.parametrize('model_name', openllm.CONFIG_MAPPING.keys())
def test_configuration_dict_protocol(model_name: str):
    """Every registered config exposes list-returning ``items``/``keys``/``values`` and converts to ``dict``."""
    config = openllm.AutoConfig.for_model(model_name)
    assert isinstance(config.items(), list)
    assert isinstance(config.keys(), list)
    assert isinstance(config.values(), list)
    assert isinstance(dict(config), dict)