mirror of
https://github.com/fastapi/fastapi.git
synced 2026-03-04 23:18:35 -05:00
✨ Improve type annotations, add support for mypy --strict, internally and for external packages (#2547)
This commit is contained in:
committed by
GitHub
parent
4fdcdf341c
commit
fdb6c9ccc5
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from starlette.responses import HTMLResponse
|
||||
@@ -13,7 +13,7 @@ def get_swagger_ui_html(
|
||||
swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css",
|
||||
swagger_favicon_url: str = "https://fastapi.tiangolo.com/img/favicon.png",
|
||||
oauth2_redirect_url: Optional[str] = None,
|
||||
init_oauth: Optional[dict] = None,
|
||||
init_oauth: Optional[Dict[str, Any]] = None,
|
||||
) -> HTMLResponse:
|
||||
|
||||
html = f"""
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.logger import logger
|
||||
from pydantic import AnyUrl, BaseModel, Field
|
||||
|
||||
try:
|
||||
import email_validator
|
||||
import email_validator # type: ignore
|
||||
|
||||
assert email_validator # make autoflake ignore the unused import
|
||||
from pydantic import EmailStr
|
||||
@@ -13,7 +13,7 @@ except ImportError: # pragma: no cover
|
||||
|
||||
class EmailStr(str): # type: ignore
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Iterable[Callable]:
|
||||
def __get_validators__(cls) -> Iterable[Callable[..., Any]]:
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -14,6 +14,7 @@ from fastapi.openapi.constants import (
|
||||
)
|
||||
from fastapi.openapi.models import OpenAPI
|
||||
from fastapi.params import Body, Param
|
||||
from fastapi.responses import Response
|
||||
from fastapi.utils import (
|
||||
deep_dict_update,
|
||||
generate_operation_id_for_path,
|
||||
@@ -64,7 +65,9 @@ status_code_ranges: Dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
def get_openapi_security_definitions(flat_dependant: Dependant) -> Tuple[Dict, List]:
|
||||
def get_openapi_security_definitions(
|
||||
flat_dependant: Dependant,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
security_definitions = {}
|
||||
operation_security = []
|
||||
for security_requirement in flat_dependant.security_requirements:
|
||||
@@ -88,13 +91,12 @@ def get_openapi_operation_parameters(
|
||||
for param in all_route_params:
|
||||
field_info = param.field_info
|
||||
field_info = cast(Param, field_info)
|
||||
# ignore mypy error until enum schemas are released
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
"in": field_info.in_.value,
|
||||
"required": param.required,
|
||||
"schema": field_schema(
|
||||
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)[0],
|
||||
}
|
||||
if field_info.description:
|
||||
@@ -109,13 +111,12 @@ def get_openapi_operation_request_body(
|
||||
*,
|
||||
body_field: Optional[ModelField],
|
||||
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
|
||||
) -> Optional[Dict]:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if not body_field:
|
||||
return None
|
||||
assert isinstance(body_field, ModelField)
|
||||
# ignore mypy error until enum schemas are released
|
||||
body_schema, _, _ = field_schema(
|
||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
|
||||
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
|
||||
)
|
||||
field_info = cast(Body, body_field.field_info)
|
||||
request_media_type = field_info.media_type
|
||||
@@ -140,7 +141,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
|
||||
return route.name.replace("_", " ").title()
|
||||
|
||||
|
||||
def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> Dict:
|
||||
def get_openapi_operation_metadata(
|
||||
*, route: routing.APIRoute, method: str
|
||||
) -> Dict[str, Any]:
|
||||
operation: Dict[str, Any] = {}
|
||||
if route.tags:
|
||||
operation["tags"] = route.tags
|
||||
@@ -154,14 +157,14 @@ def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str) -> D
|
||||
|
||||
|
||||
def get_openapi_path(
|
||||
*, route: routing.APIRoute, model_name_map: Dict[Type, str]
|
||||
) -> Tuple[Dict, Dict, Dict]:
|
||||
*, route: routing.APIRoute, model_name_map: Dict[type, str]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
path = {}
|
||||
security_schemes: Dict[str, Any] = {}
|
||||
definitions: Dict[str, Any] = {}
|
||||
assert route.methods is not None, "Methods must be a list"
|
||||
if isinstance(route.response_class, DefaultPlaceholder):
|
||||
current_response_class: Type[routing.Response] = route.response_class.value
|
||||
current_response_class: Type[Response] = route.response_class.value
|
||||
else:
|
||||
current_response_class = route.response_class
|
||||
assert current_response_class, "A response class is needed to generate OpenAPI"
|
||||
@@ -169,7 +172,7 @@ def get_openapi_path(
|
||||
if route.include_in_schema:
|
||||
for method in route.methods:
|
||||
operation = get_openapi_operation_metadata(route=route, method=method)
|
||||
parameters: List[Dict] = []
|
||||
parameters: List[Dict[str, Any]] = []
|
||||
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
|
||||
security_definitions, operation_security = get_openapi_security_definitions(
|
||||
flat_dependant=flat_dependant
|
||||
@@ -196,10 +199,15 @@ def get_openapi_path(
|
||||
if route.callbacks:
|
||||
callbacks = {}
|
||||
for callback in route.callbacks:
|
||||
cb_path, cb_security_schemes, cb_definitions, = get_openapi_path(
|
||||
route=callback, model_name_map=model_name_map
|
||||
)
|
||||
callbacks[callback.name] = {callback.path: cb_path}
|
||||
if isinstance(callback, routing.APIRoute):
|
||||
(
|
||||
cb_path,
|
||||
cb_security_schemes,
|
||||
cb_definitions,
|
||||
) = get_openapi_path(
|
||||
route=callback, model_name_map=model_name_map
|
||||
)
|
||||
callbacks[callback.name] = {callback.path: cb_path}
|
||||
operation["callbacks"] = callbacks
|
||||
status_code = str(route.status_code)
|
||||
operation.setdefault("responses", {}).setdefault(status_code, {})[
|
||||
@@ -332,21 +340,19 @@ def get_openapi(
|
||||
routes: Sequence[BaseRoute],
|
||||
tags: Optional[List[Dict[str, Any]]] = None,
|
||||
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
info = {"title": title, "version": version}
|
||||
if description:
|
||||
info["description"] = description
|
||||
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
|
||||
if servers:
|
||||
output["servers"] = servers
|
||||
components: Dict[str, Dict] = {}
|
||||
paths: Dict[str, Dict] = {}
|
||||
components: Dict[str, Dict[str, Any]] = {}
|
||||
paths: Dict[str, Dict[str, Any]] = {}
|
||||
flat_models = get_flat_models_from_routes(routes)
|
||||
# ignore mypy error until enum schemas are released
|
||||
model_name_map = get_model_name_map(flat_models) # type: ignore
|
||||
# ignore mypy error until enum schemas are released
|
||||
model_name_map = get_model_name_map(flat_models)
|
||||
definitions = get_model_definitions(
|
||||
flat_models=flat_models, model_name_map=model_name_map # type: ignore
|
||||
flat_models=flat_models, model_name_map=model_name_map
|
||||
)
|
||||
for route in routes:
|
||||
if isinstance(route, routing.APIRoute):
|
||||
@@ -368,4 +374,4 @@ def get_openapi(
|
||||
output["paths"] = paths
|
||||
if tags:
|
||||
output["tags"] = tags
|
||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)
|
||||
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user