from __future__ import annotations import logging import typing as t from hypothesis import strategies as st import openllm from openllm_core._configuration import ModelSettings logger = logging.getLogger(__name__) env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_name in openllm.CONFIG_MAPPING.keys()]) @st.composite def model_settings(draw: st.DrawFn): '''Strategy for generating ModelSettings objects.''' kwargs: dict[str, t.Any] = { 'default_id': st.text(min_size=1), 'model_ids': st.lists(st.text(), min_size=1), 'architecture': st.text(min_size=1), 'url': st.text(), 'trust_remote_code': st.booleans(), 'requirements': st.none() | st.lists(st.text(), min_size=1), 'default_backend': st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])), 'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']), 'name_type': st.sampled_from(['dasherize', 'lowercase']), 'timeout': st.integers(min_value=3600), 'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)), } return draw(st.builds(ModelSettings, **kwargs)) def make_llm_config(cls_name: str, dunder_config: dict[str, t.Any] | ModelSettings, fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None, generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None, ) -> type[openllm.LLMConfig]: globs: dict[str, t.Any] = {'openllm': openllm} _config_args: list[str] = [] lines: list[str] = [f'class {cls_name}Config(openllm.LLMConfig):'] for attr, value in dunder_config.items(): _config_args.append(f'"{attr}": __attr_{attr}') globs[f'_{cls_name}Config__attr_{attr}'] = value lines.append(f' __config__ = {{ {", ".join(_config_args)} }}') if fields is not None: for field, type_, default in fields: lines.append(f' {field}: {type_} = openllm.LLMConfig.Field({default!r})') if generation_fields is not None: generation_lines = ['class GenerationConfig:'] for field, default in generation_fields: generation_lines.append(f' {field} = {default!r}') lines.extend((' ' + line for line in generation_lines)) script = '\n'.join(lines) if openllm.utils.DEBUG: logger.info('Generated class %s:\n%s', cls_name, script) eval(compile(script, 'name', 'exec'), globs) return globs[f'{cls_name}Config']