mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-01-23 23:07:47 -05:00
feat(metadata): add configuration to metadata endpoint
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -49,7 +49,7 @@ _import_structure = {
|
||||
"_configuration": ["LLMConfig"],
|
||||
"_package": ["build"],
|
||||
"exceptions": [],
|
||||
"_schema": ["GenerationInput", "GenerationOutput"],
|
||||
"_schema": ["GenerationInput", "GenerationOutput", "MetadataOutput"],
|
||||
"utils": [],
|
||||
"models": [],
|
||||
"client": [],
|
||||
@@ -149,6 +149,7 @@ if t.TYPE_CHECKING:
|
||||
from ._package import build as build
|
||||
from ._schema import GenerationInput as GenerationInput
|
||||
from ._schema import GenerationOutput as GenerationOutput
|
||||
from ._schema import MetadataOutput as MetadataOutput
|
||||
from .cli import start as start
|
||||
from .cli import start_grpc as start_grpc
|
||||
from .models.auto import CONFIG_MAPPING as CONFIG_MAPPING
|
||||
|
||||
@@ -57,10 +57,10 @@ import typing as t
|
||||
from operator import itemgetter
|
||||
|
||||
import attr
|
||||
import click_option_group as cog
|
||||
import inflection
|
||||
import orjson
|
||||
from cattr.gen import make_dict_unstructure_fn, override
|
||||
from click_option_group import optgroup
|
||||
from deepmerge.merger import Merger
|
||||
|
||||
import openllm
|
||||
@@ -123,65 +123,6 @@ config_merger = Merger(
|
||||
)
|
||||
|
||||
|
||||
@t.overload
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[..., openllm.LLMConfig]]:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def attrs_to_options( # type: ignore (overlapping overload)
|
||||
name: str,
|
||||
field: attr.Attribute[O_co],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[P, O_co]]:
|
||||
...
|
||||
|
||||
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata["env"]
|
||||
dasherized = inflection.dasherize(name)
|
||||
underscored = inflection.underscore(name)
|
||||
|
||||
if typ in (None, attr.NOTHING):
|
||||
typ = field.type
|
||||
|
||||
full_option_name = f"--{dasherized}"
|
||||
if field.type is bool:
|
||||
full_option_name += f"/--no-{dasherized}"
|
||||
if suffix_generation:
|
||||
identifier = f"{model_name}_generation_{underscored}"
|
||||
else:
|
||||
identifier = f"{model_name}_{underscored}"
|
||||
|
||||
return optgroup.option(
|
||||
identifier,
|
||||
full_option_name,
|
||||
type=dantic.parse_type(typ),
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
multiple=dantic.allows_multiple(typ),
|
||||
help=field.metadata.get("description", "(No description provided)"),
|
||||
show_envvar=True,
|
||||
envvar=envvar,
|
||||
)
|
||||
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationConfig:
|
||||
"""Generation config provides the configuration to then be parsed to ``transformers.GenerationConfig``,
|
||||
@@ -1215,8 +1156,8 @@ class LLMConfig:
|
||||
if t.get_origin(ty) is t.Union:
|
||||
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
|
||||
continue
|
||||
f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f)
|
||||
f = optgroup.group(f"{cls.__openllm_generation_class__.__name__} generation options")(f)
|
||||
f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_generation=True)(f)
|
||||
f = cog.optgroup.group(f"{cls.__openllm_generation_class__.__name__} generation options")(f)
|
||||
|
||||
if len(cls.__openllm_accepted_keys__.difference(set(attr.fields_dict(cls.__openllm_generation_class__)))) == 0:
|
||||
# NOTE: in this case, the function is already a ClickFunctionWrapper
|
||||
@@ -1230,9 +1171,9 @@ class LLMConfig:
|
||||
if t.get_origin(ty) is t.Union or name == "generation_config":
|
||||
# NOTE: Union type is currently not yet supported, we probably just need to use environment instead.
|
||||
continue
|
||||
f = attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f)
|
||||
f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty)(f)
|
||||
|
||||
return optgroup.group(f"{cls.__name__} options")(f)
|
||||
return cog.optgroup.group(f"{cls.__name__} options")(f)
|
||||
|
||||
|
||||
bentoml_cattr.register_unstructure_hook_factory(
|
||||
|
||||
@@ -77,3 +77,12 @@ class GenerationOutput:
|
||||
|
||||
configuration: t.Dict[str, t.Any]
|
||||
"""A mapping of configuration values for given system."""
|
||||
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class MetadataOutput:
|
||||
model_id: str
|
||||
timeout: int
|
||||
model_name: str
|
||||
framework: str
|
||||
configuration: str
|
||||
|
||||
@@ -52,10 +52,11 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
|
||||
|
||||
@svc.api(input=bentoml.io.Text(), output=bentoml.io.JSON(), route="/v1/metadata")
|
||||
def metadata_v1(_: str) -> dict[str, t.Any]:
|
||||
return {
|
||||
"model_id": model_id,
|
||||
"timeout": llm_config.__openllm_timeout__,
|
||||
"model_name": llm_config.__openllm_model_name__,
|
||||
"framework": llm_config.__openllm_env__.get_framework_env(),
|
||||
}
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(
|
||||
model_id=model_id,
|
||||
timeout=llm_config.__openllm_timeout__,
|
||||
model_name=llm_config.__openllm_model_name__,
|
||||
framework=llm_config.__openllm_env__.get_framework_env(),
|
||||
configuration=llm_config.model_dump_json().decode(),
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ import inflection
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
ConfigOrderedDict = OrderedDict[str, openllm.LLMConfig]
|
||||
ConfigOrderedDict = OrderedDict[str, type[openllm.LLMConfig]]
|
||||
else:
|
||||
ConfigOrderedDict = OrderedDict
|
||||
|
||||
@@ -86,7 +86,7 @@ class _LazyConfigMapping(ConfigOrderedDict):
|
||||
self._extra_content[key] = value
|
||||
|
||||
|
||||
CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||
CONFIG_MAPPING: dict[str, type[openllm.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder"}
|
||||
|
||||
|
||||
|
||||
@@ -14,11 +14,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import _BaseAutoLLMClass, _LazyAutoMapping
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("flan_t5", "FlanT5"),
|
||||
@@ -30,7 +36,9 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
MODEL_MAPPING: dict[
|
||||
type[openllm.LLMConfig], type[openllm.LLM[transformers.PreTrainedModel, transformers.PreTrainedTokenizerFast]]
|
||||
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoLLM(_BaseAutoLLMClass):
|
||||
|
||||
@@ -14,14 +14,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import _BaseAutoLLMClass, _LazyAutoMapping
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5")])
|
||||
|
||||
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
MODEL_FLAX_MAPPING: dict[
|
||||
type[openllm.LLMConfig], type[openllm.LLM[transformers.FlaxPreTrainedModel, transformers.PreTrainedTokenizerFast]]
|
||||
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoFlaxLLM(_BaseAutoLLMClass):
|
||||
|
||||
@@ -14,14 +14,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as t
|
||||
from collections import OrderedDict
|
||||
|
||||
from .configuration_auto import CONFIG_MAPPING_NAMES
|
||||
from .factory import _BaseAutoLLMClass, _LazyAutoMapping
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
|
||||
import openllm
|
||||
|
||||
MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5")])
|
||||
|
||||
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
MODEL_TF_MAPPING: dict[
|
||||
type[openllm.LLMConfig], type[openllm.LLM[transformers.TFPreTrainedModel, transformers.PreTrainedTokenizerFast]]
|
||||
] = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoTFLLM(_BaseAutoLLMClass):
|
||||
|
||||
@@ -23,17 +23,80 @@ from enum import Enum
|
||||
|
||||
import attr
|
||||
import click
|
||||
import click_option_group as cog
|
||||
import inflection
|
||||
import orjson
|
||||
from click import ParamType
|
||||
|
||||
import openllm
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import _ValidatorType # type: ignore
|
||||
from attr import _ValidatorType
|
||||
|
||||
from .._types import ClickFunctionWrapper, F, O_co, P
|
||||
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
@t.overload
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[..., openllm.LLMConfig]]:
|
||||
...
|
||||
|
||||
|
||||
@t.overload
|
||||
def attrs_to_options( # type: ignore (overlapping overload)
|
||||
name: str,
|
||||
field: attr.Attribute[O_co],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> F[..., F[P, O_co]]:
|
||||
...
|
||||
|
||||
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: type[t.Any] | None = None,
|
||||
suffix_generation: bool = False,
|
||||
) -> t.Callable[..., ClickFunctionWrapper[..., t.Any]]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata["env"]
|
||||
dasherized = inflection.dasherize(name)
|
||||
underscored = inflection.underscore(name)
|
||||
|
||||
if typ in (None, attr.NOTHING):
|
||||
typ = field.type
|
||||
|
||||
full_option_name = f"--{dasherized}"
|
||||
if field.type is bool:
|
||||
full_option_name += f"/--no-{dasherized}"
|
||||
if suffix_generation:
|
||||
identifier = f"{model_name}_generation_{underscored}"
|
||||
else:
|
||||
identifier = f"{model_name}_{underscored}"
|
||||
|
||||
return cog.optgroup.option(
|
||||
identifier,
|
||||
full_option_name,
|
||||
type=parse_type(typ),
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
multiple=allows_multiple(typ),
|
||||
help=field.metadata.get("description", "(No description provided)"),
|
||||
show_envvar=True,
|
||||
envvar=envvar,
|
||||
)
|
||||
|
||||
|
||||
def _default_converter(value: t.Any, env: str | None) -> t.Any:
|
||||
if env is not None:
|
||||
value = os.environ.get(env, value)
|
||||
@@ -117,7 +180,7 @@ def Field(
|
||||
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
|
||||
|
||||
def parse_type(field_type: type) -> ParamType:
|
||||
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType]:
|
||||
"""Transforms the pydantic field's type into a click-compatible type.
|
||||
|
||||
Args:
|
||||
@@ -305,7 +368,7 @@ def is_container(field_type: type) -> bool:
|
||||
return openllm.utils.lenient_issubclass(origin, t.Container)
|
||||
|
||||
|
||||
def parse_container_args(field_type: type) -> ParamType | tuple[ParamType]:
|
||||
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType]:
|
||||
"""Parses the arguments inside a container type (lists, tuples and so on).
|
||||
|
||||
Args:
|
||||
|
||||
@@ -83,6 +83,11 @@ class ClientMixin:
|
||||
def model_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def llm(self) -> openllm.LLM[t.Any, t.Any]:
|
||||
if self.__llm__ is None:
|
||||
|
||||
@@ -18,6 +18,8 @@ import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
import orjson
|
||||
|
||||
import openllm
|
||||
|
||||
from .base import BaseAsyncClient, BaseClient
|
||||
@@ -63,6 +65,14 @@ class GrpcClientMixin:
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try:
|
||||
v = self._metadata.json.struct_value.fields["configuration"].string_value
|
||||
return orjson.loads(v)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
def postprocess(self, result: Response) -> openllm.GenerationOutput:
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ import logging
|
||||
import typing as t
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import orjson
|
||||
|
||||
import openllm
|
||||
|
||||
from .base import BaseAsyncClient, BaseClient
|
||||
@@ -56,6 +58,13 @@ class HTTPClientMixin:
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try:
|
||||
return orjson.loads(self._metadata["configuration"])
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)")
|
||||
|
||||
def postprocess(self, result: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return openllm.GenerationOutput(**result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user