mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-05-03 21:32:46 -04:00
types: update stubs for remaining entrypoints (#667)
* perf(type): static OpenAI types definition Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * feat: add hf types Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> * types: update remaining missing stubs Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com> --------- Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -1,21 +1,14 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
import typing as t
|
||||
|
||||
import attr
|
||||
from starlette.routing import BaseRoute, Host, Mount, Route
|
||||
from starlette.routing import Host, Mount, Route
|
||||
from starlette.schemas import EndpointInfo, SchemaGenerator
|
||||
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core.utils import first_not_none
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
|
||||
import bentoml
|
||||
|
||||
P = ParamSpec('P')
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODELS_SCHEMA = """\
|
||||
@@ -479,7 +472,7 @@ summary: Creates a model response for the given chat conversation.
|
||||
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
|
||||
|
||||
|
||||
def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
def add_schema_definitions(func):
|
||||
append_str = _SCHEMAS.get(func.__name__.lower(), '')
|
||||
if not append_str:
|
||||
return func
|
||||
@@ -490,8 +483,8 @@ def add_schema_definitions(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
|
||||
|
||||
class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
def get_endpoints(self, routes):
|
||||
endpoints_info = []
|
||||
for route in routes:
|
||||
if isinstance(route, (Mount, Host)):
|
||||
routes = route.routes or []
|
||||
@@ -523,7 +516,7 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), func))
|
||||
return endpoints_info
|
||||
|
||||
def get_schema(self, routes: list[BaseRoute], mount_path: str | None = None) -> dict[str, t.Any]:
|
||||
def get_schema(self, routes, mount_path=None):
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault('paths', {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
@@ -543,13 +536,8 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
return schema
|
||||
|
||||
|
||||
def get_generator(
|
||||
title: str,
|
||||
components: list[type[AttrsInstance]] | None = None,
|
||||
tags: list[dict[str, t.Any]] | None = None,
|
||||
inject: bool = True,
|
||||
) -> OpenLLMSchemaGenerator:
|
||||
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
|
||||
def get_generator(title, components=None, tags=None, inject=True):
|
||||
base_schema = {'info': {'title': title, 'version': API_VERSION}, 'version': OPENAPI_VERSION}
|
||||
if components and inject:
|
||||
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags and inject:
|
||||
@@ -557,12 +545,12 @@ def get_generator(
|
||||
return OpenLLMSchemaGenerator(base_schema)
|
||||
|
||||
|
||||
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
|
||||
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
def component_schema_generator(attr_cls, description=None):
|
||||
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(
|
||||
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
|
||||
)
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)):
|
||||
attr_type = field.type
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
@@ -593,7 +581,7 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
|
||||
if 'prop_schema' not in locals():
|
||||
prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
|
||||
prop_schema['default'] = field.default # type: ignore[arg-type]
|
||||
prop_schema['default'] = field.default
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
|
||||
schema['required'].append(field.name)
|
||||
schema['properties'][field.name] = prop_schema
|
||||
@@ -602,20 +590,15 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
|
||||
return schema
|
||||
|
||||
|
||||
class MKSchema:
|
||||
def __init__(self, it: dict[str, t.Any]) -> None:
|
||||
self.it = it
|
||||
|
||||
def asdict(self) -> dict[str, t.Any]:
|
||||
return self.it
|
||||
_SimpleSchema = types.new_class(
|
||||
'_SimpleSchema',
|
||||
(object,),
|
||||
{},
|
||||
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
|
||||
)
|
||||
|
||||
|
||||
def append_schemas(
|
||||
svc: bentoml.Service,
|
||||
generated_schema: dict[str, t.Any],
|
||||
tags_order: t.Literal['prepend', 'append'] = 'prepend',
|
||||
inject: bool = True,
|
||||
) -> bentoml.Service:
|
||||
def append_schemas(svc, generated_schema, tags_order='prepend', inject=True):
|
||||
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
|
||||
from bentoml._internal.service.openapi.specification import OpenAPISpecification
|
||||
|
||||
@@ -623,7 +606,7 @@ def append_schemas(
|
||||
return svc
|
||||
|
||||
svc_schema = svc.openapi_spec
|
||||
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
|
||||
if isinstance(svc_schema, (OpenAPISpecification, _SimpleSchema)):
|
||||
svc_schema = svc_schema.asdict()
|
||||
if 'tags' in generated_schema:
|
||||
if tags_order == 'prepend':
|
||||
@@ -639,12 +622,9 @@ def append_schemas(
|
||||
# HACK: mk this attribute until we have a better way to add starlette schemas.
|
||||
from bentoml._internal.service import openapi
|
||||
|
||||
def mk_generate_spec(svc, openapi_version=OPENAPI_VERSION):
|
||||
return MKSchema(svc_schema)
|
||||
def _generate_spec(svc, openapi_version=OPENAPI_VERSION):
|
||||
return _SimpleSchema(svc_schema)
|
||||
|
||||
def mk_asdict(self):
|
||||
return svc_schema
|
||||
|
||||
openapi.generate_spec = mk_generate_spec
|
||||
OpenAPISpecification.asdict = mk_asdict
|
||||
openapi.generate_spec = _generate_spec
|
||||
OpenAPISpecification.asdict = lambda self: svc_schema
|
||||
return svc
|
||||
|
||||
Reference in New Issue
Block a user