types: update stubs for remaining entrypoints (#667)

* perf(type): static OpenAI types definition

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* feat: add hf types

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* types: update remaining missing stubs

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-16 04:26:13 -05:00
committed by GitHub
parent 6102a67a83
commit 9e3f0fea15
13 changed files with 231 additions and 260 deletions

View File

@@ -1,21 +1,14 @@
from __future__ import annotations
import functools
import inspect
import types
import typing as t
import attr
from starlette.routing import BaseRoute, Host, Mount, Route
from starlette.routing import Host, Mount, Route
from starlette.schemas import EndpointInfo, SchemaGenerator
from openllm_core._typing_compat import ParamSpec
from openllm_core.utils import first_not_none
if t.TYPE_CHECKING:
from attr import AttrsInstance
import bentoml
P = ParamSpec('P')
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
# NOTE: OpenAI schema
LIST_MODELS_SCHEMA = """\
@@ -479,7 +472,7 @@ summary: Creates a model response for the given chat conversation.
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
def add_schema_definitions(func):
append_str = _SCHEMAS.get(func.__name__.lower(), '')
if not append_str:
return func
@@ -490,8 +483,8 @@ def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
class OpenLLMSchemaGenerator(SchemaGenerator):
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
endpoints_info: list[EndpointInfo] = []
def get_endpoints(self, routes):
endpoints_info = []
for route in routes:
if isinstance(route, (Mount, Host)):
routes = route.routes or []
@@ -523,7 +516,7 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
def get_schema(self, routes: list[BaseRoute], mount_path: str | None = None) -> dict[str, t.Any]:
def get_schema(self, routes, mount_path=None):
schema = dict(self.base_schema)
schema.setdefault('paths', {})
endpoints_info = self.get_endpoints(routes)
@@ -543,13 +536,8 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
return schema
def get_generator(
title: str,
components: list[type[AttrsInstance]] | None = None,
tags: list[dict[str, t.Any]] | None = None,
inject: bool = True,
) -> OpenLLMSchemaGenerator:
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
def get_generator(title, components=None, tags=None, inject=True):
base_schema = {'info': {'title': title, 'version': API_VERSION}, 'version': OPENAPI_VERSION}
if components and inject:
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
if tags is not None and tags and inject:
@@ -557,12 +545,12 @@ def get_generator(
return OpenLLMSchemaGenerator(base_schema)
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
def component_schema_generator(attr_cls, description=None):
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
schema['description'] = first_not_none(
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
)
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
for field in attr.fields(attr.resolve_types(attr_cls)):
attr_type = field.type
origin_type = t.get_origin(attr_type)
args_type = t.get_args(attr_type)
@@ -593,7 +581,7 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
if 'prop_schema' not in locals():
prop_schema = {'type': schema_type}
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
prop_schema['default'] = field.default # type: ignore[arg-type]
prop_schema['default'] = field.default
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
schema['required'].append(field.name)
schema['properties'][field.name] = prop_schema
@@ -602,20 +590,15 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
return schema
class MKSchema:
def __init__(self, it: dict[str, t.Any]) -> None:
self.it = it
def asdict(self) -> dict[str, t.Any]:
return self.it
_SimpleSchema = types.new_class(
'_SimpleSchema',
(object,),
{},
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
)
def append_schemas(
svc: bentoml.Service,
generated_schema: dict[str, t.Any],
tags_order: t.Literal['prepend', 'append'] = 'prepend',
inject: bool = True,
) -> bentoml.Service:
def append_schemas(svc, generated_schema, tags_order='prepend', inject=True):
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification
@@ -623,7 +606,7 @@ def append_schemas(
return svc
svc_schema = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
if isinstance(svc_schema, (OpenAPISpecification, _SimpleSchema)):
svc_schema = svc_schema.asdict()
if 'tags' in generated_schema:
if tags_order == 'prepend':
@@ -639,12 +622,9 @@ def append_schemas(
# HACK: mk this attribute until we have a better way to add starlette schemas.
from bentoml._internal.service import openapi
def mk_generate_spec(svc, openapi_version=OPENAPI_VERSION):
return MKSchema(svc_schema)
def _generate_spec(svc, openapi_version=OPENAPI_VERSION):
return _SimpleSchema(svc_schema)
def mk_asdict(self):
return svc_schema
openapi.generate_spec = mk_generate_spec
OpenAPISpecification.asdict = mk_asdict
openapi.generate_spec = _generate_spec
OpenAPISpecification.asdict = lambda self: svc_schema
return svc