Files
OpenLLM/openllm-python/tests/_strategies/_configuration.py
2023-08-15 02:11:14 -04:00

50 lines
2.4 KiB
Python

from __future__ import annotations
import logging
import typing as t
from hypothesis import strategies as st
import openllm
from openllm._configuration import ModelSettings
# Module-level logger; used below to report generated class source in debug mode.
logger = logging.getLogger(__name__)

# Strategy that samples one EnvVarMixin, with one candidate built per key of
# openllm.CONFIG_MAPPING.
env_strats = st.sampled_from(
    [openllm.utils.EnvVarMixin(name) for name in openllm.CONFIG_MAPPING.keys()]
)
@st.composite
def model_settings(draw: st.DrawFn):
    """Draw a ``ModelSettings`` populated from per-field strategies.

    Every key of the mapping below is paired with the strategy that produces
    its value; the mapping is forwarded as keyword arguments to
    ``st.builds(ModelSettings, ...)`` and a single example is drawn from it.
    """
    # Either no requirements at all, or a non-empty list of strings.
    requirements = st.none() | st.lists(st.text(), min_size=1)
    # Maps a resource kind to the implementation chosen for it.
    default_implementation = st.dictionaries(
        st.sampled_from(["cpu", "nvidia.com/gpu"]),
        st.sampled_from(["vllm", "pt", "tf", "flax"]),
    )
    # Either a whole number of workers (>= 1) or a fraction in [0.1, 1.0].
    workers_per_resource = st.one_of(
        st.integers(min_value=1),
        st.floats(min_value=0.1, max_value=1.0),
    )

    field_strategies: dict[str, t.Any] = {
        "default_id": st.text(min_size=1),
        "model_ids": st.lists(st.text(), min_size=1),
        "architecture": st.text(min_size=1),
        "url": st.text(),
        "requires_gpu": st.booleans(),
        "trust_remote_code": st.booleans(),
        "requirements": requirements,
        "default_implementation": default_implementation,
        "model_type": st.sampled_from(["causal_lm", "seq2seq_lm"]),
        "runtime": st.sampled_from(["transformers", "ggml"]),
        "name_type": st.sampled_from(["dasherize", "lowercase"]),
        "timeout": st.integers(min_value=3600),
        "workers_per_resource": workers_per_resource,
    }
    settings_strategy = st.builds(ModelSettings, **field_strategies)
    return draw(settings_strategy)
def make_llm_config(
    cls_name: str,
    dunder_config: dict[str, t.Any] | ModelSettings,
    fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None,
    generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,
) -> type[openllm.LLMConfig]:
    """Generate and return an ``openllm.LLMConfig`` subclass named ``{cls_name}Config``.

    The class is produced by assembling its source text line by line and then
    executing that text in a dedicated namespace.

    Args:
        cls_name: Prefix of the generated class name; ``Config`` is appended.
        dunder_config: Mapping whose items become the entries of the class's
            ``__config__`` dict. Values are not written into the source; they
            are injected into the execution namespace and referenced by name.
        fields: Optional ``(name, type expression, default)`` triples. Each one
            becomes an annotated class attribute assigned
            ``openllm.LLMConfig.Field(<repr of default>)``.
        generation_fields: Optional ``(name, default)`` pairs. When given, a
            nested ``GenerationConfig`` class is emitted with each default
            written via its ``repr``.

    Returns:
        The class object created by executing the generated source.
    """
    namespace: dict[str, t.Any] = {"openllm": openllm}

    config_entries: list[str] = []
    for attr, value in dunder_config.items():
        # Inside the class body ``__attr_<attr>`` is name-mangled by Python to
        # ``_<ClassName>__attr_<attr>``, so the value is stored under that
        # mangled name for the global lookup to succeed.
        namespace[f"_{cls_name}Config__attr_{attr}"] = value
        config_entries.append(f'"{attr}": __attr_{attr}')

    source_lines: list[str] = [f"class {cls_name}Config(openllm.LLMConfig):"]
    source_lines.append(f' __config__ = {{ {", ".join(config_entries)} }}')

    if fields is not None:
        source_lines += [
            f" {name}: {annotation} = openllm.LLMConfig.Field({default!r})"
            for name, annotation, default in fields
        ]

    if generation_fields is not None:
        nested_lines = ["class GenerationConfig:"]
        nested_lines += [f" {name} = {default!r}" for name, default in generation_fields]
        # Shift the nested class one level in so it sits inside the outer class body.
        source_lines += [" " + nested for nested in nested_lines]

    script = "\n".join(source_lines)
    if openllm.utils.DEBUG:
        logger.info("Generated class %s:\n%s", cls_name, script)

    exec(compile(script, "name", "exec"), namespace)
    return namespace[f"{cls_name}Config"]