infra: using ruff formatter (#594)

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham
2023-11-09 12:44:05 -05:00
committed by GitHub
parent 021fd453b9
commit ac377fe490
102 changed files with 5577 additions and 2540 deletions

View File

@@ -15,6 +15,7 @@ from starlette.schemas import SchemaGenerator
from openllm_core._typing_compat import ParamSpec
from openllm_core.utils import first_not_none
if t.TYPE_CHECKING:
from attr import AttrsInstance
@@ -23,7 +24,7 @@ if t.TYPE_CHECKING:
P = ParamSpec('P')
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
# NOTE: OpenAI schema
LIST_MODEL_SCHEMA = '''\
LIST_MODEL_SCHEMA = """\
---
consumes:
- application/json
@@ -53,8 +54,8 @@ responses:
owned_by: 'na'
schema:
$ref: '#/components/schemas/ModelList'
'''
CHAT_COMPLETION_SCHEMA = '''\
"""
CHAT_COMPLETION_SCHEMA = """\
---
consumes:
- application/json
@@ -191,8 +192,8 @@ responses:
}
}
description: Bad Request
'''
COMPLETION_SCHEMA = '''\
"""
COMPLETION_SCHEMA = """\
---
consumes:
- application/json
@@ -344,8 +345,8 @@ responses:
}
}
description: Bad Request
'''
HF_AGENT_SCHEMA = '''\
"""
HF_AGENT_SCHEMA = """\
---
consumes:
- application/json
@@ -389,8 +390,8 @@ responses:
schema:
$ref: '#/components/schemas/HFErrorResponse'
description: Not Found
'''
HF_ADAPTERS_SCHEMA = '''\
"""
HF_ADAPTERS_SCHEMA = """\
---
consumes:
- application/json
@@ -420,16 +421,19 @@ responses:
schema:
$ref: '#/components/schemas/HFErrorResponse'
description: Not Found
'''
"""
def add_schema_definitions(append_str: str) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
  """Decorator factory that appends an OpenAPI YAML snippet to a function's docstring.

  The returned decorator strips the existing docstring (or starts from an empty
  one), then appends ``append_str`` separated by a blank line, so the schema
  parser can pick the YAML block out of the endpoint's docstring.

  Args:
    append_str: OpenAPI schema snippet (YAML) to append to the docstring.

  Returns:
    A decorator that mutates ``func.__doc__`` in place and returns ``func`` unchanged.
  """

  def docstring_decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
    # Normalise a missing docstring to '' so the string concat below is safe.
    doc = func.__doc__ if func.__doc__ is not None else ''
    func.__doc__ = doc.strip() + '\n\n' + append_str.strip()
    return func

  return docstring_decorator
class OpenLLMSchemaGenerator(SchemaGenerator):
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
endpoints_info: list[EndpointInfo] = []
@@ -437,20 +441,29 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
if isinstance(route, (Mount, Host)):
routes = route.routes or []
path = self._remove_converter(route.path) if isinstance(route, Mount) else ''
sub_endpoints = [EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func) for sub_endpoint in self.get_endpoints(routes)]
sub_endpoints = [
EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func)
for sub_endpoint in self.get_endpoints(routes)
]
endpoints_info.extend(sub_endpoints)
elif not isinstance(route, Route) or not route.include_in_schema:
continue
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial):
elif (
inspect.isfunction(route.endpoint)
or inspect.ismethod(route.endpoint)
or isinstance(route.endpoint, functools.partial)
):
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
path = self._remove_converter(route.path)
for method in route.methods or ['GET']:
if method == 'HEAD': continue
if method == 'HEAD':
continue
endpoints_info.append(EndpointInfo(path, method.lower(), endpoint))
else:
path = self._remove_converter(route.path)
for method in ['get', 'post', 'put', 'patch', 'delete', 'options']:
if not hasattr(route.endpoint, method): continue
if not hasattr(route.endpoint, method):
continue
func = getattr(route.endpoint, method)
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
@@ -459,37 +472,52 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
schema = dict(self.base_schema)
schema.setdefault('paths', {})
endpoints_info = self.get_endpoints(routes)
if mount_path: mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
if mount_path:
mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
for endpoint in endpoints_info:
parsed = self.parse_docstring(endpoint.func)
if not parsed: continue
if not parsed:
continue
path = endpoint.path if mount_path is None else mount_path + endpoint.path
if path not in schema['paths']: schema['paths'][path] = {}
if path not in schema['paths']:
schema['paths'][path] = {}
schema['paths'][path][endpoint.http_method] = parsed
return schema
def get_generator(
  title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None
) -> OpenLLMSchemaGenerator:
  """Build an ``OpenLLMSchemaGenerator`` seeded with a base OpenAPI document.

  Args:
    title: The API title placed under ``info.title``.
    components: Optional attrs classes; each is converted to an OpenAPI component
      schema via ``component_schema_generator`` and keyed by its class name.
    tags: Optional list of OpenAPI tag objects to attach to the base schema.

  Returns:
    A schema generator whose base schema carries ``info``, ``version``, and any
    provided ``components``/``tags``.
  """
  base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
  if components:
    base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
  # NOTE: falsy-but-not-None (empty list) is also skipped, matching prior behavior.
  if tags is not None and tags:
    base_schema['tags'] = tags
  return OpenLLMSchemaGenerator(base_schema)
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
schema['description'] = first_not_none(
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
)
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
attr_type = field.type
origin_type = t.get_origin(attr_type)
args_type = t.get_args(attr_type)
# Map Python types to OpenAPI schema types
if attr_type == str: schema_type = 'string'
elif attr_type == int: schema_type = 'integer'
elif attr_type == float: schema_type = 'number'
elif attr_type == bool: schema_type = 'boolean'
if attr_type == str:
schema_type = 'string'
elif attr_type == int:
schema_type = 'integer'
elif attr_type == float:
schema_type = 'number'
elif attr_type == bool:
schema_type = 'boolean'
elif origin_type is list or origin_type is tuple:
schema_type = 'array'
elif origin_type is dict:
@@ -504,14 +532,18 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
else:
schema_type = 'string'
if 'prop_schema' not in locals(): prop_schema = {'type': schema_type}
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory): prop_schema['default'] = field.default # type: ignore[arg-type]
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)): schema['required'].append(field.name)
if 'prop_schema' not in locals():
prop_schema = {'type': schema_type}
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
prop_schema['default'] = field.default # type: ignore[arg-type]
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
schema['required'].append(field.name)
schema['properties'][field.name] = prop_schema
locals().pop('prop_schema', None)
return schema
class MKSchema:
def __init__(self, it: dict[str, t.Any]) -> None:
self.it = it
@@ -519,19 +551,30 @@ class MKSchema:
def asdict(self) -> dict[str, t.Any]:
return self.it
def append_schemas(svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend') -> bentoml.Service:
def append_schemas(
svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend'
) -> bentoml.Service:
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification
svc_schema: t.Any = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)): svc_schema = svc_schema.asdict()
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
svc_schema = svc_schema.asdict()
if 'tags' in generated_schema:
if tags_order == 'prepend': svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
elif tags_order == 'append': svc_schema['tags'].extend(generated_schema['tags'])
else: raise ValueError(f'Invalid tags_order: {tags_order}')
if 'components' in generated_schema: svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
if tags_order == 'prepend':
svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
elif tags_order == 'append':
svc_schema['tags'].extend(generated_schema['tags'])
else:
raise ValueError(f'Invalid tags_order: {tags_order}')
if 'components' in generated_schema:
svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
svc_schema['paths'].update(generated_schema['paths'])
from bentoml._internal.service import openapi # HACK: mk this attribute until we have a better way to add starlette schemas.
from bentoml._internal.service import (
openapi, # HACK: mk this attribute until we have a better way to add starlette schemas.
)
# yapf: disable
def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)