diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index ab4543d346..2e0aca1187 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -479,26 +479,22 @@ def get_openapi_path( def _get_api_route_for_openapi( - route: BaseRoute, route_context: routing._EffectiveRouteContext | None + route_context: routing.RouteContext, ) -> routing._APIRouteLike | None: - if route_context is not None and isinstance( - route_context.original_route, routing.APIRoute - ): + if isinstance(route_context.original_route, routing.APIRoute): return cast(routing._APIRouteLike, route_context) - if isinstance(route, routing.APIRoute): - return cast(routing._APIRouteLike, route) return None def get_fields_from_routes( - routes: Sequence[BaseRoute], + routes: Sequence[BaseRoute | routing.RouteContext], ) -> list[ModelField]: body_fields_from_routes: list[ModelField] = [] responses_from_routes: list[ModelField] = [] request_fields_from_routes: list[ModelField] = [] callback_flat_models: list[ModelField] = [] - for route, route_context in routing._iter_routes_with_context(routes): - api_route = _get_api_route_for_openapi(route, route_context) + for route_context in routing.iter_route_contexts(routes): + api_route = _get_api_route_for_openapi(route_context) if api_route is None: continue if api_route.include_in_schema: @@ -531,8 +527,8 @@ def get_openapi( openapi_version: str = "3.1.0", summary: str | None = None, description: str | None = None, - routes: Sequence[BaseRoute], - webhooks: Sequence[BaseRoute] | None = None, + routes: Sequence[BaseRoute | routing.RouteContext], + webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None, tags: list[dict[str, Any]] | None = None, servers: list[dict[str, str | Any]] | None = None, terms_of_service: str | None = None, @@ -567,8 +563,8 @@ def get_openapi( model_name_map=model_name_map, separate_input_output_schemas=separate_input_output_schemas, ) - for route, route_context in routing._iter_routes_with_context(routes): - api_route = _get_api_route_for_openapi(route, route_context) + for route_context in routing.iter_route_contexts(routes): + api_route = _get_api_route_for_openapi(route_context) if api_route is not None: result = get_openapi_path( route=api_route, @@ -587,8 +583,8 @@ def get_openapi( ) if path_definitions: definitions.update(path_definitions) - for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []): - api_webhook = _get_api_route_for_openapi(webhook, webhook_context) + for webhook_context in routing.iter_route_contexts(webhooks or []): + api_webhook = _get_api_route_for_openapi(webhook_context) if api_webhook is not None: result = get_openapi_path( route=api_webhook, diff --git a/fastapi/routing.py b/fastapi/routing.py index 48c0c21535..4a55fda8a8 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1454,6 +1454,47 @@ class _EffectiveRouteContext: return URLPath(path=path, protocol="http") +@dataclass(frozen=True) +class RouteContext: + route: BaseRoute + _route_context: _EffectiveRouteContext | None = field(default=None, repr=False) + + @property + def original_route(self) -> BaseRoute: + if self._route_context is not None: + return self._route_context.original_route + return self.route + + @property + def _effective_route(self) -> BaseRoute | _EffectiveRouteContext: + if self._route_context is not None: + return self._route_context + return self.route + + @property + def path(self) -> str | None: + return getattr(self._effective_route, "path", None) + + @property + def path_format(self) -> str | None: + return getattr(self._effective_route, "path_format", None) + + @property + def name(self) -> str | None: + return getattr(self._effective_route, "name", None) + + @property + def methods(self) -> set[str] | None: + return getattr(self._effective_route, "methods", None) + + @property + def endpoint(self) -> Callable[..., Any] | None: + return getattr(self._effective_route, "endpoint", None) + + def __getattr__(self, name: str) -> Any: + return getattr(self._effective_route, name) + + @dataclass class _IncludedRouter(BaseRoute): original_router: "APIRouter" @@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas yield route +def iter_route_contexts( + routes: Sequence[BaseRoute | RouteContext], +) -> Iterator[RouteContext]: + for route in routes: + if isinstance(route, RouteContext): + yield route + continue + for original_route, route_context in _iter_routes_with_context([route]): + if route_context is None: + yield RouteContext(original_route) + else: + yield RouteContext(original_route, route_context) + + def _iter_routes_with_context( routes: Sequence[BaseRoute], ) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]: diff --git a/tests/test_router_include_context.py b/tests/test_router_include_context.py index c2679aa117..cb8dc81fa9 100644 --- a/tests/test_router_include_context.py +++ b/tests/test_router_include_context.py @@ -1,16 +1,21 @@ from typing import Annotated, cast import pytest -from fastapi import APIRouter, Body, Depends, FastAPI, Request +from fastapi import APIRouter, Body, Depends, FastAPI, Request, Security from fastapi.exceptions import FastAPIError +from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse from fastapi.routing import ( APIRoute, + RouteContext, _IncludedRouter, _iter_included_route_candidates, _restore_fastapi_scope_key, + iter_route_contexts, ) +from fastapi.security import HTTPBearer from fastapi.testclient import TestClient +from pydantic import BaseModel from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router @@ -30,6 +35,104 @@ def unique_id_b(route: APIRoute) -> str: return f"b_{route.name}" +def test_iter_route_contexts_returns_direct_route_context(): + router = APIRouter() + + @router.get("/items/{item_id}") + def read_item(item_id: str): # pragma: no cover + return {"item_id": item_id} + + contexts = list(iter_route_contexts(router.routes)) + + assert len(contexts) == 1 + assert isinstance(contexts[0], RouteContext) + assert contexts[0].original_route is router.routes[0] + assert contexts[0].path == "/items/{item_id}" + assert contexts[0].path_format == "/items/{item_id}" + assert contexts[0].methods == {"GET"} + assert contexts[0].endpoint is read_item + + +def test_iter_route_contexts_supports_nested_conflict_detection(): + existing_router = APIRouter() + nested_router = APIRouter() + + @nested_router.get("/{username}") + def read_user(username: str): # pragma: no cover + return {"username": username} + + existing_router.include_router(nested_router, prefix="/auth/user") + + new_router = APIRouter() + + @new_router.get("/auth/user/{username}") + def read_user_again(username: str): # pragma: no cover + return {"username": username} + + existing_paths = { + context.path for context in iter_route_contexts(existing_router.routes) + } + new_paths = {context.path for context in iter_route_contexts(new_router.routes)} + + assert existing_paths & new_paths == {"/auth/user/{username}"} + + +def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths(): + router = APIRouter() + bearer_scheme = HTTPBearer() + + @router.get("/public", tags=["public"]) + def read_public(token: Annotated[str, Security(bearer_scheme)]): # pragma: no cover + return {"public": True} + + @router.get("/private", tags=["private"]) + def read_private(): # pragma: no cover + return {"private": True} + + app = FastAPI() + app.include_router(router, prefix="/api") + + public_routes = [ + context + for context in iter_route_contexts(app.routes) + if "public" in getattr(context, "tags", []) + ] + schema = get_openapi( + title="Public API", + version="1.0.0", + routes=public_routes, + ) + + assert set(schema["paths"]) == {"/api/public"} + assert "HTTPBearer" in schema["components"]["securitySchemes"] + + +def test_get_openapi_accepts_webhook_route_contexts(): + app = FastAPI() + bearer_scheme = HTTPBearer() + + class Subscription(BaseModel): + username: str + + @app.webhooks.post("new-subscription") + def new_subscription( + body: Subscription, token: Annotated[str, Security(bearer_scheme)] + ): # pragma: no cover + return None + + webhook_contexts = list(iter_route_contexts(app.webhooks.routes)) + schema = get_openapi( + title="Webhook API", + version="1.0.0", + routes=[], + webhooks=webhook_contexts, + ) + + assert set(schema["webhooks"]) == {"new-subscription"} + assert "HTTPBearer" in schema["components"]["securitySchemes"] + assert "Subscription" in schema["components"]["schemas"] + + def test_router_include_context_matches_flattened_include_metadata(): callback_router = APIRouter()