Add iter_route_contexts() for advanced use cases that used to use router.routes (e.g. Jupyverse) (#15785)

This commit is contained in:
Sebastián Ramírez
2026-06-18 08:49:38 +02:00
committed by GitHub
parent 7feb17f80a
commit 6ac122071d
3 changed files with 170 additions and 16 deletions

View File

@@ -479,26 +479,22 @@ def get_openapi_path(
def _get_api_route_for_openapi(
route: BaseRoute, route_context: routing._EffectiveRouteContext | None
route_context: routing.RouteContext,
) -> routing._APIRouteLike | None:
if route_context is not None and isinstance(
route_context.original_route, routing.APIRoute
):
if isinstance(route_context.original_route, routing.APIRoute):
return cast(routing._APIRouteLike, route_context)
if isinstance(route, routing.APIRoute):
return cast(routing._APIRouteLike, route)
return None
def get_fields_from_routes(
routes: Sequence[BaseRoute],
routes: Sequence[BaseRoute | routing.RouteContext],
) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
for route, route_context in routing._iter_routes_with_context(routes):
api_route = _get_api_route_for_openapi(route, route_context)
for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route_context)
if api_route is None:
continue
if api_route.include_in_schema:
@@ -531,8 +527,8 @@ def get_openapi(
openapi_version: str = "3.1.0",
summary: str | None = None,
description: str | None = None,
routes: Sequence[BaseRoute],
webhooks: Sequence[BaseRoute] | None = None,
routes: Sequence[BaseRoute | routing.RouteContext],
webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None,
tags: list[dict[str, Any]] | None = None,
servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None,
@@ -567,8 +563,8 @@ def get_openapi(
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
for route, route_context in routing._iter_routes_with_context(routes):
api_route = _get_api_route_for_openapi(route, route_context)
for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route_context)
if api_route is not None:
result = get_openapi_path(
route=api_route,
@@ -587,8 +583,8 @@ def get_openapi(
)
if path_definitions:
definitions.update(path_definitions)
for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []):
api_webhook = _get_api_route_for_openapi(webhook, webhook_context)
for webhook_context in routing.iter_route_contexts(webhooks or []):
api_webhook = _get_api_route_for_openapi(webhook_context)
if api_webhook is not None:
result = get_openapi_path(
route=api_webhook,

View File

@@ -1454,6 +1454,47 @@ class _EffectiveRouteContext:
return URLPath(path=path, protocol="http")
@dataclass(frozen=True)
class RouteContext:
route: BaseRoute
_route_context: _EffectiveRouteContext | None = field(default=None, repr=False)
@property
def original_route(self) -> BaseRoute:
if self._route_context is not None:
return self._route_context.original_route
return self.route
@property
def _effective_route(self) -> BaseRoute | _EffectiveRouteContext:
if self._route_context is not None:
return self._route_context
return self.route
@property
def path(self) -> str | None:
return getattr(self._effective_route, "path", None)
@property
def path_format(self) -> str | None:
return getattr(self._effective_route, "path_format", None)
@property
def name(self) -> str | None:
return getattr(self._effective_route, "name", None)
@property
def methods(self) -> set[str] | None:
return getattr(self._effective_route, "methods", None)
@property
def endpoint(self) -> Callable[..., Any] | None:
return getattr(self._effective_route, "endpoint", None)
def __getattr__(self, name: str) -> Any:
return getattr(self._effective_route, name)
@dataclass
class _IncludedRouter(BaseRoute):
original_router: "APIRouter"
@@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas
yield route
def iter_route_contexts(
routes: Sequence[BaseRoute | RouteContext],
) -> Iterator[RouteContext]:
for route in routes:
if isinstance(route, RouteContext):
yield route
continue
for original_route, route_context in _iter_routes_with_context([route]):
if route_context is None:
yield RouteContext(original_route)
else:
yield RouteContext(original_route, route_context)
def _iter_routes_with_context(
routes: Sequence[BaseRoute],
) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]:

View File

@@ -1,16 +1,21 @@
from typing import Annotated, cast
import pytest
from fastapi import APIRouter, Body, Depends, FastAPI, Request
from fastapi import APIRouter, Body, Depends, FastAPI, Request, Security
from fastapi.exceptions import FastAPIError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.routing import (
APIRoute,
RouteContext,
_IncludedRouter,
_iter_included_route_candidates,
_restore_fastapi_scope_key,
iter_route_contexts,
)
from fastapi.security import HTTPBearer
from fastapi.testclient import TestClient
from pydantic import BaseModel
from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router
@@ -30,6 +35,104 @@ def unique_id_b(route: APIRoute) -> str:
return f"b_{route.name}"
def test_iter_route_contexts_returns_direct_route_context():
router = APIRouter()
@router.get("/items/{item_id}")
def read_item(item_id: str): # pragma: no cover
return {"item_id": item_id}
contexts = list(iter_route_contexts(router.routes))
assert len(contexts) == 1
assert isinstance(contexts[0], RouteContext)
assert contexts[0].original_route is router.routes[0]
assert contexts[0].path == "/items/{item_id}"
assert contexts[0].path_format == "/items/{item_id}"
assert contexts[0].methods == {"GET"}
assert contexts[0].endpoint is read_item
def test_iter_route_contexts_supports_nested_conflict_detection():
existing_router = APIRouter()
nested_router = APIRouter()
@nested_router.get("/{username}")
def read_user(username: str): # pragma: no cover
return {"username": username}
existing_router.include_router(nested_router, prefix="/auth/user")
new_router = APIRouter()
@new_router.get("/auth/user/{username}")
def read_user_again(username: str): # pragma: no cover
return {"username": username}
existing_paths = {
context.path for context in iter_route_contexts(existing_router.routes)
}
new_paths = {context.path for context in iter_route_contexts(new_router.routes)}
assert existing_paths & new_paths == {"/auth/user/{username}"}
def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths():
router = APIRouter()
bearer_scheme = HTTPBearer()
@router.get("/public", tags=["public"])
def read_public(token: Annotated[str, Security(bearer_scheme)]): # pragma: no cover
return {"public": True}
@router.get("/private", tags=["private"])
def read_private(): # pragma: no cover
return {"private": True}
app = FastAPI()
app.include_router(router, prefix="/api")
public_routes = [
context
for context in iter_route_contexts(app.routes)
if "public" in getattr(context, "tags", [])
]
schema = get_openapi(
title="Public API",
version="1.0.0",
routes=public_routes,
)
assert set(schema["paths"]) == {"/api/public"}
assert "HTTPBearer" in schema["components"]["securitySchemes"]
def test_get_openapi_accepts_webhook_route_contexts():
app = FastAPI()
bearer_scheme = HTTPBearer()
class Subscription(BaseModel):
username: str
@app.webhooks.post("new-subscription")
def new_subscription(
body: Subscription, token: Annotated[str, Security(bearer_scheme)]
): # pragma: no cover
return None
webhook_contexts = list(iter_route_contexts(app.webhooks.routes))
schema = get_openapi(
title="Webhook API",
version="1.0.0",
routes=[],
webhooks=webhook_contexts,
)
assert set(schema["webhooks"]) == {"new-subscription"}
assert "HTTPBearer" in schema["components"]["securitySchemes"]
assert "Subscription" in schema["components"]["schemas"]
def test_router_include_context_matches_flattened_include_metadata():
callback_router = APIRouter()