Compare commits

...

2 Commits

Author SHA1 Message Date
Sebastián Ramírez
1c46aa2d52 🎨 Add types and format for linting 2022-04-13 17:31:23 +02:00
Sebastián Ramírez
edcae918e2 ♻️ Refactor include_router to mount sub-routers 2022-04-13 13:01:42 +02:00
5 changed files with 610 additions and 162 deletions

View File

@@ -137,6 +137,10 @@ class FastAPI(Starlette):
self.middleware_stack: ASGIApp = self.build_middleware_stack() self.middleware_stack: ASGIApp = self.build_middleware_stack()
self.setup() self.setup()
@property
def routes(self) -> List[BaseRoute]:
return list(self.router.iter_all_routes())
def build_middleware_stack(self) -> ASGIApp: def build_middleware_stack(self) -> ASGIApp:
# Duplicate/override from Starlette to add AsyncExitStackMiddleware # Duplicate/override from Starlette to add AsyncExitStackMiddleware
# inside of ExceptionMiddleware, inside of custom user middlewares # inside of ExceptionMiddleware, inside of custom user middlewares

View File

@@ -152,7 +152,7 @@ def generate_operation_id(
) )
if route.operation_id: if route.operation_id:
return route.operation_id return route.operation_id
path: str = route.path_format path: str = route._route_full_path_format
return generate_operation_id_for_path(name=route.name, path=path, method=method) return generate_operation_id_for_path(name=route.name, path=path, method=method)
@@ -243,7 +243,7 @@ def get_openapi_path(
model_name_map=model_name_map, model_name_map=model_name_map,
operation_ids=operation_ids, operation_ids=operation_ids,
) )
callbacks[callback.name] = {callback.path: cb_path} callbacks[callback.name] = {callback._route_full_path: cb_path}
operation["callbacks"] = callbacks operation["callbacks"] = callbacks
if route.status_code is not None: if route.status_code is not None:
status_code = str(route.status_code) status_code = str(route.status_code)
@@ -422,7 +422,7 @@ def get_openapi(
if result: if result:
path, security_schemes, path_definitions = result path, security_schemes, path_definitions = result
if path: if path:
paths.setdefault(route.path_format, {}).update(path) paths.setdefault(route._route_full_path_format, {}).update(path)
if security_schemes: if security_schemes:
components.setdefault("securitySchemes", {}).update( components.setdefault("securitySchemes", {}).update(
security_schemes security_schemes

View File

@@ -9,12 +9,14 @@ from typing import (
Callable, Callable,
Coroutine, Coroutine,
Dict, Dict,
Iterator,
List, List,
Optional, Optional,
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
Type, Type,
TypeVar,
Union, Union,
) )
@@ -57,6 +59,10 @@ from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp, Scope from starlette.types import ASGIApp, Scope
from starlette.websockets import WebSocket from starlette.websockets import WebSocket
APIRouteType = TypeVar("APIRouteType", bound="APIRoute")
APIRouterType = TypeVar("APIRouterType", bound="APIRouter")
APIMountType = TypeVar("APIMountType", bound="APIMount")
def _prepare_response_content( def _prepare_response_content(
res: Any, res: Any,
@@ -305,6 +311,8 @@ class APIWebSocketRoute(routing.WebSocketRoute):
class APIRoute(routing.Route): class APIRoute(routing.Route):
_route_full_path_format: str # only for mypy
def __init__( def __init__(
self, self,
path: str, path: str,
@@ -338,13 +346,13 @@ class APIRoute(routing.Route):
generate_unique_id_function: Union[ generate_unique_id_function: Union[
Callable[["APIRoute"], str], DefaultPlaceholder Callable[["APIRoute"], str], DefaultPlaceholder
] = Default(generate_unique_id), ] = Default(generate_unique_id),
router: Optional["APIRouter"] = None,
) -> None: ) -> None:
self.path = path self.path = path
self.endpoint = endpoint self.endpoint = endpoint
self.response_model = response_model self.response_model = response_model
self.summary = summary self.summary = summary
self.response_description = response_description self.response_description = response_description
self.deprecated = deprecated
self.operation_id = operation_id self.operation_id = operation_id
self.response_model_include = response_model_include self.response_model_include = response_model_include
self.response_model_exclude = response_model_exclude self.response_model_exclude = response_model_exclude
@@ -352,34 +360,128 @@ class APIRoute(routing.Route):
self.response_model_exclude_unset = response_model_exclude_unset self.response_model_exclude_unset = response_model_exclude_unset
self.response_model_exclude_defaults = response_model_exclude_defaults self.response_model_exclude_defaults = response_model_exclude_defaults
self.response_model_exclude_none = response_model_exclude_none self.response_model_exclude_none = response_model_exclude_none
self.include_in_schema = include_in_schema
self.response_class = response_class
self.dependency_overrides_provider = dependency_overrides_provider self.dependency_overrides_provider = dependency_overrides_provider
self.callbacks = callbacks
self.openapi_extra = openapi_extra self.openapi_extra = openapi_extra
self.generate_unique_id_function = generate_unique_id_function self.router = router
self.tags = tags or []
self.responses = responses or {}
self.name = get_name(endpoint) if name is None else name self.name = get_name(endpoint) if name is None else name
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
if methods is None:
methods = ["GET"]
self.methods: Set[str] = set([method.upper() for method in methods])
if isinstance(generate_unique_id_function, DefaultPlaceholder):
current_generate_unique_id: Callable[
["APIRoute"], str
] = generate_unique_id_function.value
else:
current_generate_unique_id = generate_unique_id_function
self.unique_id = self.operation_id or current_generate_unique_id(self)
# normalize enums e.g. http.HTTPStatus # normalize enums e.g. http.HTTPStatus
if isinstance(status_code, IntEnum): if isinstance(status_code, IntEnum):
status_code = int(status_code) status_code = int(status_code)
self.status_code = status_code self.status_code = status_code
if methods is None:
methods = ["GET"]
self.methods: Set[str] = set([method.upper() for method in methods])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
self.description = self.description.split("\f")[0]
assert callable(endpoint), "An endpoint must be a callable"
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path
)
# Attributes set in route used to compute resolved attributes
self._route_deprecated = deprecated
self._route_include_in_schema = include_in_schema
self._route_response_class = response_class
self._route_callbacks = callbacks
self._route_generate_unique_id_function = generate_unique_id_function
self._route_tags = tags or []
self._route_responses = responses or {}
if dependencies:
self._route_dependencies = dependencies
else:
self._route_dependencies = []
self.setup()
def setup(self) -> None:
# setup full path
self._route_full_path = self.path
if self.router:
self._route_full_path = self.router._router_full_path + self.path
# setup dependencies
self.dependencies: List[params.Depends] = []
if self.router:
self.dependencies.extend(self.router.dependencies)
self.dependencies.extend(self._route_dependencies)
# setup generate_unique_id
generate_unique_id_functions: List[
Union[Callable[[APIRoute], str], DefaultPlaceholder]
] = [self._route_generate_unique_id_function]
if self.router:
generate_unique_id_functions.append(self.router.generate_unique_id_function)
current_generate_unique_id_function = get_value_or_default(
*generate_unique_id_functions
)
self.generate_unique_id_function: Union[
Callable[[APIRoute], str], DefaultPlaceholder
] = current_generate_unique_id_function
# setup responses
responses: Dict[Union[int, str], Dict[str, Any]] = {}
if self.router:
responses.update(self.router.responses)
responses.update(self._route_responses)
self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
# setup default_response_class
default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
self._route_response_class
]
if self.router:
default_response_classes.append(self.router.default_response_class)
current_default_response_class = get_value_or_default(*default_response_classes)
self.response_class: Union[
Type[Response], DefaultPlaceholder
] = current_default_response_class
# setup tags
self.tags: List[Union[str, Enum]] = []
if self.router:
self.tags.extend(self.router.tags)
self.tags.extend(self._route_tags)
# setup callbacks
callbacks: List[BaseRoute] = []
if self.router:
callbacks.extend(self.router.callbacks)
if self._route_callbacks:
callbacks.extend(self._route_callbacks)
self.callbacks = callbacks
# setup deprecated
self.deprecated = self._route_deprecated
if self.router:
self.deprecated = self._route_deprecated or self.router.deprecated
# setup include_in_schema
self.include_in_schema = self._route_include_in_schema
if self.router:
self.include_in_schema = (
self._route_include_in_schema and self.router.include_in_schema
)
_, self._route_full_path_format, _ = compile_path(self._route_full_path)
if isinstance(self.generate_unique_id_function, DefaultPlaceholder):
resolved_generate_unique_id: Callable[
["APIRoute"], str
] = self.generate_unique_id_function.value
else:
resolved_generate_unique_id = self.generate_unique_id_function
self.unique_id = self.operation_id or resolved_generate_unique_id(self)
if self.response_model: if self.response_model:
assert ( assert (
status_code not in STATUS_CODES_WITH_NO_BODY self.status_code not in STATUS_CODES_WITH_NO_BODY
), f"Status code {status_code} must not have a response body" ), f"Status code {self.status_code} must not have a response body"
response_name = "Response_" + self.unique_id response_name = "Response_" + self.unique_id
self.response_field = create_response_field( self.response_field = create_response_field(
name=response_name, type_=self.response_model name=response_name, type_=self.response_model
@@ -397,14 +499,7 @@ class APIRoute(routing.Route):
else: else:
self.response_field = None # type: ignore self.response_field = None # type: ignore
self.secure_cloned_response_field = None self.secure_cloned_response_field = None
if dependencies:
self.dependencies = list(dependencies)
else:
self.dependencies = []
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
# truncate description text to the content preceding the first "form feed"
self.description = self.description.split("\f")[0]
response_fields = {} response_fields = {}
for additional_status_code, response in self.responses.items(): for additional_status_code, response in self.responses.items():
assert isinstance(response, dict), "An additional response must be a dict" assert isinstance(response, dict), "An additional response must be a dict"
@@ -421,16 +516,50 @@ class APIRoute(routing.Route):
else: else:
self.response_fields = {} self.response_fields = {}
assert callable(endpoint), "An endpoint must be a callable" self.dependant = get_dependant(
self.dependant = get_dependant(path=self.path_format, call=self.endpoint) path=self._route_full_path_format, call=self.endpoint
)
for depends in self.dependencies[::-1]: for depends in self.dependencies[::-1]:
self.dependant.dependencies.insert( self.dependant.dependencies.insert(
0, 0,
get_parameterless_sub_dependant(depends=depends, path=self.path_format), get_parameterless_sub_dependant(
depends=depends, path=self._route_full_path_format
),
) )
self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id) self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id)
self.app = request_response(self.get_route_handler()) self.app = request_response(self.get_route_handler())
def copy(self: APIRouteType) -> APIRouteType:
return type(self)(
path=self.path,
endpoint=self.endpoint,
response_model=self.response_model,
status_code=self.status_code,
tags=self._route_tags,
dependencies=self._route_dependencies,
summary=self.summary,
description=self.description,
response_description=self.response_description,
responses=self._route_responses,
deprecated=self._route_deprecated,
name=self.name,
methods=self.methods,
operation_id=self.operation_id,
response_model_include=self.response_model_include,
response_model_exclude=self.response_model_exclude,
response_model_by_alias=self.response_model_by_alias,
response_model_exclude_unset=self.response_model_exclude_unset,
response_model_exclude_defaults=self.response_model_exclude_defaults,
response_model_exclude_none=self.response_model_exclude_none,
include_in_schema=self._route_include_in_schema,
response_class=self._route_response_class,
dependency_overrides_provider=self.dependency_overrides_provider,
callbacks=self._route_callbacks,
openapi_extra=self.openapi_extra,
generate_unique_id_function=self._route_generate_unique_id_function,
router=self.router,
)
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
return get_request_handler( return get_request_handler(
dependant=self.dependant, dependant=self.dependant,
@@ -476,6 +605,7 @@ class APIRouter(routing.Router):
generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id generate_unique_id
), ),
parent_router: Optional["APIRouter"] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
routes=routes, # type: ignore # in Starlette routes=routes, # type: ignore # in Starlette
@@ -490,16 +620,151 @@ class APIRouter(routing.Router):
"/" "/"
), "A path prefix must not end with '/', as the routes will start with '/'" ), "A path prefix must not end with '/', as the routes will start with '/'"
self.prefix = prefix self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or []) or []
self.deprecated = deprecated
self.include_in_schema = include_in_schema
self.responses = responses or {}
self.callbacks = callbacks or []
self.dependency_overrides_provider = dependency_overrides_provider self.dependency_overrides_provider = dependency_overrides_provider
self.route_class = route_class self.route_class = route_class
self.default_response_class = default_response_class
self.generate_unique_id_function = generate_unique_id_function self.parent_router = parent_router
# Attributes set in router used to compute resolved attributes
self._router_dependencies = list(dependencies or []) or []
self._router_generate_unique_id_function = generate_unique_id_function
self._router_responses = responses or {}
self._router_default_response_class = default_response_class
self._router_tags: List[Union[str, Enum]] = tags or []
self._router_callbacks = callbacks or []
self._router_deprecated = deprecated
self._router_include_in_schema = include_in_schema
self._router_has_empty_route = False
self._router_has_root_route = False
self.setup()
def setup(self) -> None:
# setup full path
self._router_full_path = self.prefix
if self.parent_router:
self._router_full_path = self.parent_router._router_full_path + self.prefix
# setup dependencies
self.dependencies: List[params.Depends] = []
if self.parent_router:
self.dependencies.extend(self.parent_router.dependencies)
self.dependencies.extend(self._router_dependencies)
# setup generate_unique_id
generate_unique_id_functions: List[
Union[Callable[[APIRoute], str], DefaultPlaceholder]
] = [self._router_generate_unique_id_function]
if self.parent_router:
generate_unique_id_functions.append(
self.parent_router.generate_unique_id_function
)
current_generate_unique_id_function = get_value_or_default(
*generate_unique_id_functions
)
self.generate_unique_id_function: Union[
Callable[[APIRoute], str], DefaultPlaceholder
] = current_generate_unique_id_function
# setup responses
responses: Dict[Union[int, str], Dict[str, Any]] = {}
if self.parent_router:
responses.update(self.parent_router.responses)
responses.update(self._router_responses)
self.responses: Dict[Union[int, str], Dict[str, Any]] = responses
# setup default_response_class
default_response_classes: List[Union[Type[Response], DefaultPlaceholder]] = [
self._router_default_response_class
]
if self.parent_router:
default_response_classes.append(self.parent_router.default_response_class)
current_default_response_class = get_value_or_default(*default_response_classes)
self.default_response_class: Union[
Type[Response], DefaultPlaceholder
] = current_default_response_class
# setup tags
self.tags: List[Union[str, Enum]] = []
if self.parent_router:
self.tags.extend(self.parent_router.tags)
self.tags.extend(self._router_tags)
# setup callbacks
self.callbacks: List[BaseRoute] = []
if self.parent_router:
self.callbacks.extend(self.parent_router.callbacks)
self.callbacks.extend(self._router_callbacks)
# setup deprecated
self.deprecated = self._router_deprecated
if self.parent_router:
self.deprecated = self._router_deprecated or self.parent_router.deprecated
# setup include_in_schema
self.include_in_schema = self._router_include_in_schema
if self.parent_router:
self.include_in_schema = (
self._router_include_in_schema and self.parent_router.include_in_schema
)
# setup routes
for route in self.routes:
if isinstance(route, APIRoute):
route.router = self
route.setup()
elif isinstance(route, APIMount):
route.parent_router = self
route.setup()
def copy(self: APIRouterType) -> APIRouterType:
routes: List[routing.BaseRoute] = []
for route in self.routes:
if isinstance(route, APIRoute):
routes.append(route.copy())
elif isinstance(route, APIMount):
routes.append(route.copy())
else:
routes.append(route)
copied_router = type(self)(
prefix=self.prefix,
tags=self._router_tags,
dependencies=self._router_dependencies,
default_response_class=self._router_default_response_class,
responses=self._router_responses,
callbacks=self._router_callbacks,
routes=routes,
redirect_slashes=self.redirect_slashes,
default=self.default,
dependency_overrides_provider=self.dependency_overrides_provider,
route_class=self.route_class,
on_startup=self.on_startup,
on_shutdown=self.on_shutdown,
deprecated=self._router_deprecated,
include_in_schema=self._router_include_in_schema,
generate_unique_id_function=self._router_generate_unique_id_function,
parent_router=self.parent_router,
)
copied_router._router_has_empty_route = self._router_has_empty_route
copied_router._router_has_root_route = self._router_has_root_route
for route in copied_router.routes:
if isinstance(route, APIRoute):
route.router = copied_router
route.setup()
elif isinstance(route, Mount):
if isinstance(route.app, APIRouter):
route.app.setup()
return copied_router
def iter_all_routes(self) -> Iterator[routing.BaseRoute]:
for route in self.routes:
if isinstance(route, Mount):
if isinstance(route.app, APIRouter):
yield from route.app.iter_all_routes()
else:
yield route
def api_mount(self, router: "APIRouter", name: Optional[str] = None) -> None:
route = APIMount(router=router, name=name, parent_router=self)
self.routes.append(route)
def add_api_route( def add_api_route(
self, self,
@@ -537,34 +802,18 @@ class APIRouter(routing.Router):
) -> None: ) -> None:
route_class = route_class_override or self.route_class route_class = route_class_override or self.route_class
responses = responses or {} responses = responses or {}
combined_responses = {**self.responses, **responses}
current_response_class = get_value_or_default(
response_class, self.default_response_class
)
current_tags = self.tags.copy()
if tags:
current_tags.extend(tags)
current_dependencies = self.dependencies.copy()
if dependencies:
current_dependencies.extend(dependencies)
current_callbacks = self.callbacks.copy()
if callbacks:
current_callbacks.extend(callbacks)
current_generate_unique_id = get_value_or_default(
generate_unique_id_function, self.generate_unique_id_function
)
route = route_class( route = route_class(
self.prefix + path, path,
endpoint=endpoint, endpoint=endpoint,
response_model=response_model, response_model=response_model,
status_code=status_code, status_code=status_code,
tags=current_tags, tags=tags,
dependencies=current_dependencies, dependencies=dependencies,
summary=summary, summary=summary,
description=description, description=description,
response_description=response_description, response_description=response_description,
responses=combined_responses, responses=responses,
deprecated=deprecated or self.deprecated, deprecated=deprecated,
methods=methods, methods=methods,
operation_id=operation_id, operation_id=operation_id,
response_model_include=response_model_include, response_model_include=response_model_include,
@@ -573,15 +822,20 @@ class APIRouter(routing.Router):
response_model_exclude_unset=response_model_exclude_unset, response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults, response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none, response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema and self.include_in_schema, include_in_schema=include_in_schema,
response_class=current_response_class, response_class=response_class,
name=name, name=name,
dependency_overrides_provider=self.dependency_overrides_provider, dependency_overrides_provider=self.dependency_overrides_provider,
callbacks=current_callbacks, callbacks=callbacks,
openapi_extra=openapi_extra, openapi_extra=openapi_extra,
generate_unique_id_function=current_generate_unique_id, generate_unique_id_function=generate_unique_id_function,
router=self,
) )
self.routes.append(route) self.routes.append(route)
if not path:
self._router_has_empty_route = True
if path == "/":
self._router_has_root_route = True
def api_route( def api_route(
self, self,
@@ -680,103 +934,197 @@ class APIRouter(routing.Router):
generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id generate_unique_id
), ),
copy_flat_routes: Optional[bool] = None,
) -> None: ) -> None:
if prefix: if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'" assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith( assert not prefix.endswith(
"/" "/"
), "A path prefix must not end with '/', as the routes will start with '/'" ), "A path prefix must not end with '/', as the routes will start with '/'"
else: resolved_copy_flat_routes = copy_flat_routes
for r in router.routes: if resolved_copy_flat_routes is None:
path = getattr(r, "path") resolved_copy_flat_routes = not (prefix or router.prefix)
name = getattr(r, "name", "unknown") if not resolved_copy_flat_routes:
if path is not None and not path: included_router = router.copy()
raise Exception( if (
f"Prefix and path cannot be both empty (path operation: {name})" prefix
or tags
or dependencies
or not isinstance(default_response_class, DefaultPlaceholder)
or responses
or callbacks
or deprecated is not None
or include_in_schema is not True
or not isinstance(generate_unique_id_function, DefaultPlaceholder)
):
current_router = type(self)(
prefix=prefix,
tags=tags,
dependencies=dependencies,
default_response_class=default_response_class,
responses=responses,
callbacks=callbacks,
deprecated=deprecated,
include_in_schema=include_in_schema,
generate_unique_id_function=generate_unique_id_function,
parent_router=self,
)
# current_router.api_mount(included_router)
current_router.include_router(included_router)
if included_router._router_has_empty_route and not self.prefix:
current_router._router_has_empty_route = True
current_router._router_has_root_route = (
included_router._router_has_root_route
) )
if responses is None: self.api_mount(current_router)
responses = {} included_router.parent_router = current_router
for route in router.routes: else:
if isinstance(route, APIRoute): self.api_mount(included_router)
combined_responses = {**responses, **route.responses} included_router.parent_router = self
use_response_class = get_value_or_default(
route.response_class, included_router.setup()
router.default_response_class, else:
default_response_class, # TODO: remove this and its test, as a subrouter can mount another
self.default_response_class, # subrouter (done automatically of other things are overwritten) and both
) # can omit a prefix, this would error out
current_tags = [] # for r in router.routes:
if tags: # path = getattr(r, "path")
current_tags.extend(tags) # name = getattr(r, "name", "unknown")
if route.tags: # if path is not None and not path:
current_tags.extend(route.tags) # raise Exception(
current_dependencies: List[params.Depends] = [] # f"Prefix and path cannot be both empty (path operation: {name})"
if dependencies: # )
current_dependencies.extend(dependencies) if responses is None:
if route.dependencies: responses = {}
current_dependencies.extend(route.dependencies) for route in router.routes:
current_callbacks = [] if isinstance(route, APIRoute):
if callbacks: combined_responses = {}
current_callbacks.extend(callbacks) if route.router:
if route.callbacks: combined_responses.update(route.router.responses)
current_callbacks.extend(route.callbacks) combined_responses.update(responses)
current_generate_unique_id = get_value_or_default( combined_responses.update(route.responses)
route.generate_unique_id_function,
router.generate_unique_id_function, response_classes: List[
generate_unique_id_function, Union[Type[Response], DefaultPlaceholder]
self.generate_unique_id_function, ] = []
) if route.router:
self.add_api_route( response_classes.append(route.router.default_response_class)
prefix + route.path, response_classes.extend(
route.endpoint, [
response_model=route.response_model, route.response_class,
status_code=route.status_code, router.default_response_class,
tags=current_tags, default_response_class,
dependencies=current_dependencies, self.default_response_class,
summary=route.summary, ]
description=route.description, )
response_description=route.response_description, use_response_class = get_value_or_default(*response_classes)
responses=combined_responses, current_tags = []
deprecated=route.deprecated or deprecated or self.deprecated, if route.router:
methods=route.methods, current_tags.extend(route.router.tags)
operation_id=route.operation_id, if tags:
response_model_include=route.response_model_include, current_tags.extend(tags)
response_model_exclude=route.response_model_exclude, if route.tags:
response_model_by_alias=route.response_model_by_alias, current_tags.extend(route.tags)
response_model_exclude_unset=route.response_model_exclude_unset, current_dependencies: List[params.Depends] = []
response_model_exclude_defaults=route.response_model_exclude_defaults, if route.router:
response_model_exclude_none=route.response_model_exclude_none, current_dependencies.extend(route.router.dependencies)
include_in_schema=route.include_in_schema if dependencies:
and self.include_in_schema current_dependencies.extend(dependencies)
and include_in_schema, if route.dependencies:
response_class=use_response_class, current_dependencies.extend(route.dependencies)
name=route.name, current_callbacks = []
route_class_override=type(route), if route.router:
callbacks=current_callbacks, current_callbacks.extend(route.router.callbacks)
openapi_extra=route.openapi_extra, if callbacks:
generate_unique_id_function=current_generate_unique_id, current_callbacks.extend(callbacks)
) if route.callbacks:
elif isinstance(route, routing.Route): current_callbacks.extend(route.callbacks)
methods = list(route.methods or []) # type: ignore # in Starlette
self.add_route( generate_unique_id_functions: List[
prefix + route.path, Union[Callable[[APIRoute], str], DefaultPlaceholder]
route.endpoint, ] = []
methods=methods, if route.router:
include_in_schema=route.include_in_schema, generate_unique_id_functions.append(
name=route.name, route.router.generate_unique_id_function
) )
elif isinstance(route, APIWebSocketRoute): generate_unique_id_functions.extend(
self.add_api_websocket_route( [
prefix + route.path, route.endpoint, name=route.name route.generate_unique_id_function,
) router.generate_unique_id_function,
elif isinstance(route, routing.WebSocketRoute): generate_unique_id_function,
self.add_websocket_route( self.generate_unique_id_function,
prefix + route.path, route.endpoint, name=route.name ]
) )
for handler in router.on_startup: current_generate_unique_id_function = get_value_or_default(
self.add_event_handler("startup", handler) *generate_unique_id_functions
for handler in router.on_shutdown: )
self.add_event_handler("shutdown", handler) path = prefix + route.path
if route.router:
path = prefix + route.router.prefix + path
self.add_api_route(
path,
route.endpoint,
response_model=route.response_model,
status_code=route.status_code,
tags=current_tags,
dependencies=current_dependencies,
summary=route.summary,
description=route.description,
response_description=route.response_description,
responses=combined_responses,
deprecated=route.deprecated or deprecated or self.deprecated,
methods=route.methods,
operation_id=route.operation_id,
response_model_include=route.response_model_include,
response_model_exclude=route.response_model_exclude,
response_model_by_alias=route.response_model_by_alias,
response_model_exclude_unset=route.response_model_exclude_unset,
response_model_exclude_defaults=route.response_model_exclude_defaults,
response_model_exclude_none=route.response_model_exclude_none,
include_in_schema=route.include_in_schema
and self.include_in_schema
and include_in_schema,
response_class=use_response_class,
name=route.name,
route_class_override=type(route),
callbacks=current_callbacks,
openapi_extra=route.openapi_extra,
generate_unique_id_function=current_generate_unique_id_function,
)
elif isinstance(route, APIMount):
self.include_router(
route.app,
prefix=prefix,
tags=tags,
dependencies=dependencies,
default_response_class=default_response_class,
responses=responses,
callbacks=callbacks,
deprecated=deprecated,
include_in_schema=include_in_schema,
generate_unique_id_function=generate_unique_id_function,
)
elif isinstance(route, routing.Route):
methods = list(route.methods or []) # type: ignore # in Starlette
self.add_route(
prefix + route.path,
route.endpoint,
methods=methods,
include_in_schema=route.include_in_schema,
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
self.add_api_websocket_route(
prefix + route.path, route.endpoint, name=route.name
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name
)
for handler in router.on_startup:
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
def get( def get(
self, self,
@@ -1226,3 +1574,100 @@ class APIRouter(routing.Router):
openapi_extra=openapi_extra, openapi_extra=openapi_extra,
generate_unique_id_function=generate_unique_id_function, generate_unique_id_function=generate_unique_id_function,
) )
class APIMount(routing.Mount):
def __init__(
self,
router: APIRouter,
*,
name: Optional[str] = None,
parent_router: Optional[APIRouter] = None,
) -> None:
self.name = name # type: ignore # in Starlette
self.parent_router = parent_router
self.router = router
self.setup()
def setup(self) -> None:
self.app: APIRouter = self.router.copy()
if self.parent_router:
self.app.parent_router = self.parent_router
self.app.setup()
self.path = self.app.prefix
self.path_regex, self.path_format, self.param_convertors = compile_path(
self.path + "/{path:path}"
)
# Add custom additional root without trailing slash for compatibility with
# include_router and possibly app migrations
# Ref: https://github.com/tiangolo/fastapi/issues/414
(
self._root_path_regex,
self._root_path_format,
self._root_param_convertors,
) = compile_path(self.path)
(
self._root_path_regex_trailing,
self._root_path_format_trailing,
self._root_param_convertors_trailing,
) = compile_path(self.path + "/")
def copy(self: APIMountType) -> APIMountType:
return type(self)(
router=self.router.copy(),
name=self.name,
parent_router=self.parent_router,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
if scope["type"] in ("http", "websocket"):
path = scope["path"]
if self.app._router_has_empty_route:
# Custom logic to support paths without trailing slash
# Ref: https://github.com/tiangolo/fastapi/issues/414
# This mixes the code in
# starlette.routing.Route.matches() and starlette.routing.Mount.matches()
match = self._root_path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path,
"path": "",
"endpoint": self.app,
}
return Match.FULL, child_scope
if not self.app._router_has_root_route:
match = self._root_path_regex_trailing.match(path)
if match:
return Match.NONE, {}
# End of custom logic
# Duplicated code from Starlette
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
return Match.NONE, {}
# End of duplicated code from Starlette

View File

@@ -139,7 +139,7 @@ def generate_operation_id_for_path(
def generate_unique_id(route: "APIRoute") -> str: def generate_unique_id(route: "APIRoute") -> str:
operation_id = route.name + route.path_format operation_id = route.name + route._route_full_path_format
operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id) operation_id = re.sub("[^0-9a-zA-Z_]", "_", operation_id)
assert route.methods assert route.methods
operation_id = operation_id + "_" + list(route.methods)[0].lower() operation_id = operation_id + "_" + list(route.methods)[0].lower()

View File

@@ -2,7 +2,6 @@ import pytest
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from starlette.routing import Route
app = FastAPI() app = FastAPI()
@@ -107,9 +106,9 @@ def test_get_path(path, expected_status, expected_response):
def test_route_classes(): def test_route_classes():
routes = {} routes = {}
for r in app.router.routes: for r in app.router.iter_all_routes():
assert isinstance(r, Route) if isinstance(r, APIRoute):
routes[r.path] = r routes[r._route_full_path_format] = r
assert getattr(routes["/a/"], "x_type") == "A" assert getattr(routes["/a/"], "x_type") == "A"
assert getattr(routes["/a/b/"], "x_type") == "B" assert getattr(routes["/a/b/"], "x_type") == "B"
assert getattr(routes["/a/b/c/"], "x_type") == "C" assert getattr(routes["/a/b/c/"], "x_type") == "C"