diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index ab4543d346..2e0aca1187 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -479,26 +479,22 @@ def get_openapi_path( def _get_api_route_for_openapi( - route: BaseRoute, route_context: routing._EffectiveRouteContext | None + route_context: routing.RouteContext, ) -> routing._APIRouteLike | None: - if route_context is not None and isinstance( - route_context.original_route, routing.APIRoute - ): + if isinstance(route_context.original_route, routing.APIRoute): return cast(routing._APIRouteLike, route_context) - if isinstance(route, routing.APIRoute): - return cast(routing._APIRouteLike, route) return None def get_fields_from_routes( - routes: Sequence[BaseRoute], + routes: Sequence[BaseRoute | routing.RouteContext], ) -> list[ModelField]: body_fields_from_routes: list[ModelField] = [] responses_from_routes: list[ModelField] = [] request_fields_from_routes: list[ModelField] = [] callback_flat_models: list[ModelField] = [] - for route, route_context in routing._iter_routes_with_context(routes): - api_route = _get_api_route_for_openapi(route, route_context) + for route_context in routing.iter_route_contexts(routes): + api_route = _get_api_route_for_openapi(route_context) if api_route is None: continue if api_route.include_in_schema: @@ -531,8 +527,8 @@ def get_openapi( openapi_version: str = "3.1.0", summary: str | None = None, description: str | None = None, - routes: Sequence[BaseRoute], - webhooks: Sequence[BaseRoute] | None = None, + routes: Sequence[BaseRoute | routing.RouteContext], + webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None, tags: list[dict[str, Any]] | None = None, servers: list[dict[str, str | Any]] | None = None, terms_of_service: str | None = None, @@ -567,8 +563,8 @@ def get_openapi( model_name_map=model_name_map, separate_input_output_schemas=separate_input_output_schemas, ) - for route, route_context in routing._iter_routes_with_context(routes): - api_route = _get_api_route_for_openapi(route, route_context) + for route_context in routing.iter_route_contexts(routes): + api_route = _get_api_route_for_openapi(route_context) if api_route is not None: result = get_openapi_path( route=api_route, @@ -587,8 +583,8 @@ def get_openapi( ) if path_definitions: definitions.update(path_definitions) - for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []): - api_webhook = _get_api_route_for_openapi(webhook, webhook_context) + for webhook_context in routing.iter_route_contexts(webhooks or []): + api_webhook = _get_api_route_for_openapi(webhook_context) if api_webhook is not None: result = get_openapi_path( route=api_webhook, diff --git a/fastapi/routing.py b/fastapi/routing.py index 48c0c21535..4a55fda8a8 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -1454,6 +1454,47 @@ class _EffectiveRouteContext: return URLPath(path=path, protocol="http") +@dataclass(frozen=True) +class RouteContext: + route: BaseRoute + _route_context: _EffectiveRouteContext | None = field(default=None, repr=False) + + @property + def original_route(self) -> BaseRoute: + if self._route_context is not None: + return self._route_context.original_route + return self.route + + @property + def _effective_route(self) -> BaseRoute | _EffectiveRouteContext: + if self._route_context is not None: + return self._route_context + return self.route + + @property + def path(self) -> str | None: + return getattr(self._effective_route, "path", None) + + @property + def path_format(self) -> str | None: + return getattr(self._effective_route, "path_format", None) + + @property + def name(self) -> str | None: + return getattr(self._effective_route, "name", None) + + @property + def methods(self) -> set[str] | None: + return getattr(self._effective_route, "methods", None) + + @property + def endpoint(self) -> Callable[..., Any] | None: + return getattr(self._effective_route, "endpoint", None) + + def __getattr__(self, name: str) -> Any: + return getattr(self._effective_route, name) + + @dataclass class _IncludedRouter(BaseRoute): original_router: "APIRouter" @@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas yield route +def iter_route_contexts( + routes: Sequence[BaseRoute | RouteContext], +) -> Iterator[RouteContext]: + for route in routes: + if isinstance(route, RouteContext): + yield route + continue + for original_route, route_context in _iter_routes_with_context([route]): + if route_context is None: + yield RouteContext(original_route) + else: + yield RouteContext(original_route, route_context) + + def _iter_routes_with_context( routes: Sequence[BaseRoute], ) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]: diff --git a/tests/test_router_include_context.py b/tests/test_router_include_context.py index c2679aa117..3707401b0f 100644 --- a/tests/test_router_include_context.py +++ b/tests/test_router_include_context.py @@ -3,12 +3,15 @@ from typing import Annotated, cast import pytest from fastapi import APIRouter, Body, Depends, FastAPI, Request from fastapi.exceptions import FastAPIError +from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse from fastapi.routing import ( APIRoute, + RouteContext, _IncludedRouter, _iter_included_route_candidates, _restore_fastapi_scope_key, + iter_route_contexts, ) from fastapi.testclient import TestClient from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router @@ -30,6 +33,123 @@ def unique_id_b(route: APIRoute) -> str: return f"b_{route.name}" +def test_iter_route_contexts_returns_direct_route_context(): + router = APIRouter() + + @router.get("/items/{item_id}") + def read_item(item_id: str): + return {"item_id": item_id} + + contexts = list(iter_route_contexts(router.routes)) + + assert len(contexts) == 1 + assert isinstance(contexts[0], RouteContext) + assert contexts[0].original_route is router.routes[0] + assert contexts[0].path == "/items/{item_id}" + assert contexts[0].path_format == "/items/{item_id}" + assert contexts[0].methods == {"GET"} + + +def test_iter_route_contexts_returns_nested_effective_paths(): + leaf_router = APIRouter() + + @leaf_router.get("/me") + def read_me(): + return {"me": True} + + child_router = APIRouter() + child_router.include_router(leaf_router, prefix="/user") + + parent_router = APIRouter() + parent_router.include_router(child_router, prefix="/auth") + + app = FastAPI() + app.include_router(parent_router, prefix="/api") + + contexts = [ + context + for context in iter_route_contexts(app.routes) + if getattr(context, "name", None) == "read_me" + ] + + assert len(contexts) == 1 + assert contexts[0].path == "/api/auth/user/me" + assert contexts[0].path_format == "/api/auth/user/me" + assert contexts[0].endpoint is read_me + + +def test_iter_route_contexts_returns_each_inclusion_of_same_router(): + router = APIRouter() + + @router.get("/items") + def read_items(): + return [] + + parent_router = APIRouter() + parent_router.include_router(router, prefix="/v1") + parent_router.include_router(router, prefix="/v2") + + paths = [ + context.path + for context in iter_route_contexts(parent_router.routes) + if getattr(context, "name", None) == "read_items" + ] + + assert paths == ["/v1/items", "/v2/items"] + + +def test_iter_route_contexts_supports_nested_conflict_detection(): + existing_router = APIRouter() + nested_router = APIRouter() + + @nested_router.get("/me") + def read_me(): + return {"me": True} + + existing_router.include_router(nested_router, prefix="/auth/user") + + new_router = APIRouter() + + @new_router.get("/auth/user/me") + def read_me_again(): + return {"me": False} + + existing_paths = { + context.path for context in iter_route_contexts(existing_router.routes) + } + new_paths = {context.path for context in iter_route_contexts(new_router.routes)} + + assert existing_paths & new_paths == {"/auth/user/me"} + + +def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths(): + router = APIRouter() + + @router.get("/public", tags=["public"]) + def read_public(): + return {"public": True} + + @router.get("/private", tags=["private"]) + def read_private(): + return {"private": True} + + app = FastAPI() + app.include_router(router, prefix="/api") + + public_routes = [ + context + for context in iter_route_contexts(app.routes) + if "public" in getattr(context, "tags", []) + ] + schema = get_openapi( + title="Public API", + version="1.0.0", + routes=public_routes, + ) + + assert set(schema["paths"]) == {"/api/public"} + + def test_router_include_context_matches_flattened_include_metadata(): callback_router = APIRouter()