From cb76a894cf83bb5ab15ea90f28330228a3df110e Mon Sep 17 00:00:00 2001 From: Aaron <29749331+aarnphm@users.noreply.github.com> Date: Tue, 13 Jun 2023 07:09:13 -0400 Subject: [PATCH] feat(metadata): add configuration to metadata endpoint Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --- src/openllm/__init__.py | 3 +- src/openllm/_configuration.py | 69 ++----------------- src/openllm/_schema.py | 9 +++ src/openllm/_service.py | 15 ++-- src/openllm/models/auto/configuration_auto.py | 4 +- src/openllm/models/auto/modeling_auto.py | 10 ++- src/openllm/models/auto/modeling_flax_auto.py | 10 ++- src/openllm/models/auto/modeling_tf_auto.py | 10 ++- src/openllm/utils/dantic.py | 69 ++++++++++++++++++- src/openllm_client/runtimes/base.py | 5 ++ src/openllm_client/runtimes/grpc.py | 10 +++ src/openllm_client/runtimes/http.py | 9 +++ 12 files changed, 143 insertions(+), 80 deletions(-) diff --git a/src/openllm/__init__.py b/src/openllm/__init__.py index 18f257e1..c27ef3d9 100644 --- a/src/openllm/__init__.py +++ b/src/openllm/__init__.py @@ -49,7 +49,7 @@ _import_structure = { "_configuration": ["LLMConfig"], "_package": ["build"], "exceptions": [], - "_schema": ["GenerationInput", "GenerationOutput"], + "_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput"], "utils": [], "models": [], "client": [], @@ -149,6 +149,7 @@ if t.TYPE_CHECKING: from ._package import build as build from ._schema import GenerationInput as GenerationInput from ._schema import GenerationOutput as GenerationOutput + from ._schema import MetadataOutput as MetadataOutput from .cli import start as start from .cli import start_grpc as start_grpc from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING diff --git a/src/openllm/_configuration.py b/src/openllm/_configuration.py index 56605d0a..c98ca833 100644 --- a/src/openllm/_configuration.py +++ b/src/openllm/_configuration.py @@ -57,10 +57,10 @@ import typing as t from operator import itemgetter import attr +import click_option_group as cog import inflection import orjson from cattr.gen import make_dict_unstructure_fn, override -from click_option_group import optgroup from deepmerge.merger import Merger import openllm @@ -123,65 +123,6 @@ config_merger = Merger( ) -@t.overload -def attrs_to_options( - name: str, - field: attr.Attribute[t.Any], - model_name: str, - typ: type[t.Any] | None = None, - suffix_generation: bool = False, -) -> F[..., F[..., openllm.LLMConfig]]: - ... - - -@t.overload -def attrs_to_options( # type: ignore (overlapping overload) - name: str, - field: attr.Attribute[O_co], - model_name: str, - typ: type[t.Any] | None = None, - suffix_generation: bool = False, -) -> F[..., F[P, O_co]]: - ... - - -def attrs_to_options( - name: str, - field: attr.Attribute[t.Any], - model_name: str, - typ: type[t.Any] | None = None, - suffix_generation: bool = False, -) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]: - # TODO: support parsing nested attrs class and Union - envvar = field.metadata["env"] - dasherized = inflection.dasherize(name) - underscored = inflection.underscore(name) - - if typ in (None, attr.NOTHING): - typ = field.type - - full_option_name = f"--{dasherized}" - if field.type is bool: - full_option_name += f"/--no-{dasherized}" - if suffix_generation: - identifier = f"{model_name}_generation_{underscored}" - else: - identifier = f"{model_name}_{underscored}" - - return optgroup.option( - identifier, - full_option_name, - type=dantic.parse_type(typ), - required=field.default is attr.NOTHING, - default=field.default if field.default not in (attr.NOTHING, None) else None, - show_default=True, - multiple=dantic.allows_multiple(typ), - help=field.metadata.get("description", "(No description provided)"), - show_envvar=True, - envvar=envvar, - ) - - @attr.frozen(slots=True) class GenerationConfig: """Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``, @@ -1215,8 +1156,8 @@ class LLMConfig: if t.get_origin(ty) is t.Union: # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. continue - f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f) - f = optgroup.group(f"{cls.__openllm_generation_class__.__name__} generation options")(f) + f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f) + f = cog.optgroup.group(f"{cls.__openllm_generation_class__.__name__} generation options")(f) if len(cls.__openllm_accepted_keys__.difference(set(attr.fields_dict(cls.__openllm_generation_class__)))) == 0: # NOTE: in this case, the function is already a ClickFunctionWrapper @@ -1230,9 +1171,9 @@ class LLMConfig: if t.get_origin(ty) is t.Union or name == "generation_config": # NOTE: Union type is currently not yet supported, we probably just need to use environment instead. continue - f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f) + f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f) - return optgroup.group(f"{cls.__name__} options")(f) + return cog.optgroup.group(f"{cls.__name__} options")(f) bentoml_cattr.register_unstructure_hook_factory( diff --git a/src/openllm/_schema.py b/src/openllm/_schema.py index df45631a..096d42d6 100644 --- a/src/openllm/_schema.py +++ b/src/openllm/_schema.py @@ -77,3 +77,12 @@ class GenerationOutput: configuration: t.Dict[str, t.Any] """A mapping of configuration values for given system.""" + + +@attr.frozen(slots=True) +class MetadataOutput: + model_id: str + timeout: int + model_name: str + framework: str + configuration: str diff --git a/src/openllm/_service.py b/src/openllm/_service.py index 3d3768bf..58d07053 100644 --- a/src/openllm/_service.py +++ b/src/openllm/_service.py @@ -52,10 +52,11 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput: @svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON(), route="/v1/metadata") -def metadata_v1(_: str) -> dict[str, t.Any]: - return { - "model_id": model_id, - "timeout": llm_config.__openllm_timeout__, - "model_name": llm_config.__openllm_model_name__, - "framework": llm_config.__openllm_env__.get_framework_env(), - } +def metadata_v1(_: str) -> openllm.MetadataOutput: + return openllm.MetadataOutput( + model_id=model_id, + timeout=llm_config.__openllm_timeout__, + model_name=llm_config.__openllm_model_name__, + framework=llm_config.__openllm_env__.get_framework_env(), + configuration=llm_config.model_dump_json().decode(), + ) diff --git a/src/openllm/models/auto/configuration_auto.py b/src/openllm/models/auto/configuration_auto.py index a36012d5..57d05600 100644 --- a/src/openllm/models/auto/configuration_auto.py +++ b/src/openllm/models/auto/configuration_auto.py @@ -23,7 +23,7 @@ import inflection import openllm if t.TYPE_CHECKING: - ConfigOrderedDict = OrderedDict[str, openllm.LLMConfig] + ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]] else: ConfigOrderedDict = OrderedDict @@ -86,7 +86,7 @@ class _LazyConfigMapping(ConfigOrderedDict): self._extra_content[key] = value -CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES) +CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES) CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder"} diff --git a/src/openllm/models/auto/modeling_auto.py b/src/openllm/models/auto/modeling_auto.py index e25d1a94..b0ad9c6a 100644 --- a/src/openllm/models/auto/modeling_auto.py +++ b/src/openllm/models/auto/modeling_auto.py @@ -14,11 +14,17 @@ from __future__ import annotations +import typing as t from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import _BaseAutoLLMClass, _LazyAutoMapping +if t.TYPE_CHECKING: + import transformers + + import openllm + MODEL_MAPPING_NAMES = OrderedDict( [ ("flan_t5", "FlanT5"), @@ -30,7 +36,9 @@ MODEL_MAPPING_NAMES = OrderedDict( ] ) -MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) +MODEL_MAPPING: dict[ + type[openllm.LLMConfig], type[openllm.LLM[transformers.PreTrainedModel, transformers.PreTrainedTokenizerFast]] +] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) class AutoLLM(_BaseAutoLLMClass): diff --git a/src/openllm/models/auto/modeling_flax_auto.py b/src/openllm/models/auto/modeling_flax_auto.py index 2e4af02f..44bd6686 100644 --- a/src/openllm/models/auto/modeling_flax_auto.py +++ b/src/openllm/models/auto/modeling_flax_auto.py @@ -14,14 +14,22 @@ from __future__ import annotations +import typing as t from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import _BaseAutoLLMClass, _LazyAutoMapping +if t.TYPE_CHECKING: + import transformers + + import openllm + MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5")]) -MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES) +MODEL_FLAX_MAPPING: dict[ + type[openllm.LLMConfig], type[openllm.LLM[transformers.FlaxPreTrainedModel, transformers.PreTrainedTokenizerFast]] +] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES) class AutoFlaxLLM(_BaseAutoLLMClass): diff --git a/src/openllm/models/auto/modeling_tf_auto.py b/src/openllm/models/auto/modeling_tf_auto.py index 052df96b..1f35d0e4 100644 --- a/src/openllm/models/auto/modeling_tf_auto.py +++ b/src/openllm/models/auto/modeling_tf_auto.py @@ -14,14 +14,22 @@ from __future__ import annotations +import typing as t from collections import OrderedDict from .configuration_auto import CONFIG_MAPPING_NAMES from .factory import _BaseAutoLLMClass, _LazyAutoMapping +if t.TYPE_CHECKING: + import transformers + + import openllm + MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5")]) -MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES) +MODEL_TF_MAPPING: dict[ + type[openllm.LLMConfig], type[openllm.LLM[transformers.TFPreTrainedModel, transformers.PreTrainedTokenizerFast]] +] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES) class AutoTFLLM(_BaseAutoLLMClass): diff --git a/src/openllm/utils/dantic.py b/src/openllm/utils/dantic.py index 8b62391e..74f0bff0 100644 --- a/src/openllm/utils/dantic.py +++ b/src/openllm/utils/dantic.py @@ -23,17 +23,80 @@ from enum import Enum import attr import click +import click_option_group as cog +import inflection import orjson from click import ParamType import openllm if t.TYPE_CHECKING: - from attr import _ValidatorType # type: ignore + from attr import _ValidatorType + + from .._types import ClickFunctionWrapper, F, O_co, P _T = t.TypeVar("_T") +@t.overload +def attrs_to_options( + name: str, + field: attr.Attribute[t.Any], + model_name: str, + typ: type[t.Any] | None = None, + suffix_generation: bool = False, +) -> F[..., F[..., openllm.LLMConfig]]: + ... + + +@t.overload +def attrs_to_options( # type: ignore (overlapping overload) + name: str, + field: attr.Attribute[O_co], + model_name: str, + typ: type[t.Any] | None = None, + suffix_generation: bool = False, +) -> F[..., F[P, O_co]]: + ... + + +def attrs_to_options( + name: str, + field: attr.Attribute[t.Any], + model_name: str, + typ: type[t.Any] | None = None, + suffix_generation: bool = False, +) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]: + # TODO: support parsing nested attrs class and Union + envvar = field.metadata["env"] + dasherized = inflection.dasherize(name) + underscored = inflection.underscore(name) + + if typ in (None, attr.NOTHING): + typ = field.type + + full_option_name = f"--{dasherized}" + if field.type is bool: + full_option_name += f"/--no-{dasherized}" + if suffix_generation: + identifier = f"{model_name}_generation_{underscored}" + else: + identifier = f"{model_name}_{underscored}" + + return cog.optgroup.option( + identifier, + full_option_name, + type=parse_type(typ), + required=field.default is attr.NOTHING, + default=field.default if field.default not in (attr.NOTHING, None) else None, + show_default=True, + multiple=allows_multiple(typ), + help=field.metadata.get("description", "(No description provided)"), + show_envvar=True, + envvar=envvar, + ) + + def _default_converter(value: t.Any, env: str | None) -> t.Any: if env is not None: value = os.environ.get(env, value) @@ -117,7 +180,7 @@ def Field( return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs) -def parse_type(field_type: type) -> ParamType: +def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]: """Transforms the pydantic field's type into a click-compatible type. Args: @@ -305,7 +368,7 @@ def is_container(field_type: type) -> bool: return openllm.utils.lenient_issubclass(origin, t.Container) -def parse_container_args(field_type: type) -> ParamType | tuple[ParamType]: +def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType]: """Parses the arguments inside a container type (lists, tuples and so on). Args: diff --git a/src/openllm_client/runtimes/base.py b/src/openllm_client/runtimes/base.py index 693bc306..77f3aa3b 100644 --- a/src/openllm_client/runtimes/base.py +++ b/src/openllm_client/runtimes/base.py @@ -83,6 +83,11 @@ class ClientMixin: def model_id(self) -> str: raise NotImplementedError + @property + @abstractmethod + def configuration(self) -> dict[str, t.Any]: + raise NotImplementedError + @property def llm(self) -> openllm.LLM[t.Any, t.Any]: if self.__llm__ is None: diff --git a/src/openllm_client/runtimes/grpc.py b/src/openllm_client/runtimes/grpc.py index 12a27f59..dfea9356 100644 --- a/src/openllm_client/runtimes/grpc.py +++ b/src/openllm_client/runtimes/grpc.py @@ -18,6 +18,8 @@ import asyncio import logging import typing as t +import orjson + import openllm from .base import BaseAsyncClient, BaseClient @@ -63,6 +65,14 @@ class GrpcClientMixin: except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") + @property + def configuration(self) -> dict[str, t.Any]: + try: + v = self._metadata.json.struct_value.fields["configuration"].string_value + return orjson.loads(v) + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") + def postprocess(self, result: Response) -> openllm.GenerationOutput: from google.protobuf.json_format import MessageToDict diff --git a/src/openllm_client/runtimes/http.py b/src/openllm_client/runtimes/http.py index ff35202c..36c79477 100644 --- a/src/openllm_client/runtimes/http.py +++ b/src/openllm_client/runtimes/http.py @@ -18,6 +18,8 @@ import logging import typing as t from urllib.parse import urlparse +import orjson + import openllm from .base import BaseAsyncClient, BaseClient @@ -56,6 +58,13 @@ class HTTPClientMixin: except KeyError: raise RuntimeError("Malformed service endpoint. (Possible malicious)") + @property + def configuration(self) -> dict[str, t.Any]: + try: + return orjson.loads(self._metadata["configuration"]) + except KeyError: + raise RuntimeError("Malformed service endpoint. (Possible malicious)") + def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput: return openllm.GenerationOutput(**result)