Improve type annotations, add support for mypy --strict, internally and for external packages (#2547)

This commit is contained in:
Sebastián Ramírez
2020-12-20 19:50:00 +01:00
committed by GitHub
parent 4fdcdf341c
commit fdb6c9ccc5
43 changed files with 314 additions and 244 deletions

View File

@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Dict, Optional
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
@@ -13,7 +13,7 @@ def get_swagger_ui_html(
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
oauth2_redirect_url: Optional[str] = None,
init_oauth: Optional[dict] = None,
init_oauth: Optional[Dict[str, Any]] = None,
) -> HTMLResponse:
html = f"""

View File

@@ -5,7 +5,7 @@ from fastapi.logger import logger
from pydantic import AnyUrl, BaseModel, Field
try:
import email_validator
import email_validator # type: ignore
assert email_validator # make autoflake ignore the unused import
from pydantic import EmailStr
@@ -13,7 +13,7 @@ except ImportError: # pragma: no cover
class EmailStr(str): # type: ignore
@classmethod
def __get_validators__(cls) -> Iterable[Callable]:
def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod

View File

@@ -14,6 +14,7 @@ from fastapi.openapi.constants import (
)
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.utils import (
deep_dict_update,
generate_operation_id_for_path,
@@ -64,7 +65,9 @@ status_code_ranges: Dict[str, str] = {
}
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
@@ -88,13 +91,12 @@ def get_openapi_operation_parameters(
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
# ignore mypy error until enum schemas are released
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": field_schema(
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0],
}
if field_info.description:
@@ -109,13 +111,12 @@ def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict]:
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
# ignore mypy error until enum schemas are released
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
@@ -140,7 +141,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
return route.name.replace("_", " ").title()
def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
def get_openapi_operation_metadata(
*, route: routing.APIRoute, method: str
) -> Dict[str, Any]:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
@@ -154,14 +157,14 @@ def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> D
def get_openapi_path(
*, route: routing.APIRoute, model_name_map: Dict[Type, str]
) -> Tuple[Dict, Dict, Dict]:
*, route: routing.APIRoute, model_name_map: Dict[type, str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: Type[routing.Response] = route.response_class.value
current_response_class: Type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
@@ -169,7 +172,7 @@ def get_openapi_path(
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method)
parameters: List[Dict] = []
parameters: List[Dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
@@ -196,10 +199,15 @@ def get_openapi_path(
if route.callbacks:
callbacks = {}
for callback in route.callbacks:
cb_path, cb_security_schemes, cb_definitions, = get_openapi_path(
route=callback, model_name_map=model_name_map
)
callbacks[callback.name] = {callback.path: cb_path}
if isinstance(callback, routing.APIRoute):
(
cb_path,
cb_security_schemes,
cb_definitions,
) = get_openapi_path(
route=callback, model_name_map=model_name_map
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
status_code = str(route.status_code)
operation.setdefault("responses", {}).setdefault(status_code, {})[
@@ -332,21 +340,19 @@ def get_openapi(
routes: Sequence[BaseRoute],
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
) -> Dict:
) -> Dict[str, Any]:
info = {"title": title, "version": version}
if description:
info["description"] = description
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict] = {}
paths: Dict[str, Dict] = {}
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
flat_models = get_flat_models_from_routes(routes)
# ignore mypy error until enum schemas are released
model_name_map = get_model_name_map(flat_models) # type: ignore
# ignore mypy error until enum schemas are released
model_name_map = get_model_name_map(flat_models)
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map # type: ignore
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
if isinstance(route, routing.APIRoute):
@@ -368,4 +374,4 @@ def get_openapi(
output["paths"] = paths
if tags:
output["tags"] = tags
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore