Add iter_route_contexts() for advanced use cases that used to use router.routes (e.g. Jupyverse)

This commit is contained in:
Sebastián Ramírez
2026-06-17 20:38:57 +02:00
parent 202b2d2f5f
commit 29a91ce53e
3 changed files with 186 additions and 15 deletions

View File

@@ -479,26 +479,22 @@ def get_openapi_path(
def _get_api_route_for_openapi(
route: BaseRoute, route_context: routing._EffectiveRouteContext | None
route_context: routing.RouteContext,
) -> routing._APIRouteLike | None:
if route_context is not None and isinstance(
route_context.original_route, routing.APIRoute
):
if isinstance(route_context.original_route, routing.APIRoute):
return cast(routing._APIRouteLike, route_context)
if isinstance(route, routing.APIRoute):
return cast(routing._APIRouteLike, route)
return None
def get_fields_from_routes(
routes: Sequence[BaseRoute],
routes: Sequence[BaseRoute | routing.RouteContext],
) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
for route, route_context in routing._iter_routes_with_context(routes):
api_route = _get_api_route_for_openapi(route, route_context)
for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route_context)
if api_route is None:
continue
if api_route.include_in_schema:
@@ -531,8 +527,8 @@ def get_openapi(
openapi_version: str = "3.1.0",
summary: str | None = None,
description: str | None = None,
routes: Sequence[BaseRoute],
webhooks: Sequence[BaseRoute] | None = None,
routes: Sequence[BaseRoute | routing.RouteContext],
webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None,
tags: list[dict[str, Any]] | None = None,
servers: list[dict[str, str | Any]] | None = None,
terms_of_service: str | None = None,
@@ -567,8 +563,8 @@ def get_openapi(
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
for route, route_context in routing._iter_routes_with_context(routes):
api_route = _get_api_route_for_openapi(route, route_context)
for route_context in routing.iter_route_contexts(routes):
api_route = _get_api_route_for_openapi(route_context)
if api_route is not None:
result = get_openapi_path(
route=api_route,
@@ -587,8 +583,8 @@ def get_openapi(
)
if path_definitions:
definitions.update(path_definitions)
for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []):
api_webhook = _get_api_route_for_openapi(webhook, webhook_context)
for webhook_context in routing.iter_route_contexts(webhooks or []):
api_webhook = _get_api_route_for_openapi(webhook_context)
if api_webhook is not None:
result = get_openapi_path(
route=api_webhook,

View File

@@ -1454,6 +1454,47 @@ class _EffectiveRouteContext:
return URLPath(path=path, protocol="http")
@dataclass(frozen=True)
class RouteContext:
route: BaseRoute
_route_context: _EffectiveRouteContext | None = field(default=None, repr=False)
@property
def original_route(self) -> BaseRoute:
if self._route_context is not None:
return self._route_context.original_route
return self.route
@property
def _effective_route(self) -> BaseRoute | _EffectiveRouteContext:
if self._route_context is not None:
return self._route_context
return self.route
@property
def path(self) -> str | None:
return getattr(self._effective_route, "path", None)
@property
def path_format(self) -> str | None:
return getattr(self._effective_route, "path_format", None)
@property
def name(self) -> str | None:
return getattr(self._effective_route, "name", None)
@property
def methods(self) -> set[str] | None:
return getattr(self._effective_route, "methods", None)
@property
def endpoint(self) -> Callable[..., Any] | None:
return getattr(self._effective_route, "endpoint", None)
def __getattr__(self, name: str) -> Any:
return getattr(self._effective_route, name)
@dataclass
class _IncludedRouter(BaseRoute):
original_router: "APIRouter"
@@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas
yield route
def iter_route_contexts(
routes: Sequence[BaseRoute | RouteContext],
) -> Iterator[RouteContext]:
for route in routes:
if isinstance(route, RouteContext):
yield route
continue
for original_route, route_context in _iter_routes_with_context([route]):
if route_context is None:
yield RouteContext(original_route)
else:
yield RouteContext(original_route, route_context)
def _iter_routes_with_context(
routes: Sequence[BaseRoute],
) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]:

View File

@@ -3,12 +3,15 @@ from typing import Annotated, cast
import pytest
from fastapi import APIRouter, Body, Depends, FastAPI, Request
from fastapi.exceptions import FastAPIError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
from fastapi.routing import (
APIRoute,
RouteContext,
_IncludedRouter,
_iter_included_route_candidates,
_restore_fastapi_scope_key,
iter_route_contexts,
)
from fastapi.testclient import TestClient
from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router
@@ -30,6 +33,123 @@ def unique_id_b(route: APIRoute) -> str:
return f"b_{route.name}"
def test_iter_route_contexts_returns_direct_route_context():
router = APIRouter()
@router.get("/items/{item_id}")
def read_item(item_id: str):
return {"item_id": item_id}
contexts = list(iter_route_contexts(router.routes))
assert len(contexts) == 1
assert isinstance(contexts[0], RouteContext)
assert contexts[0].original_route is router.routes[0]
assert contexts[0].path == "/items/{item_id}"
assert contexts[0].path_format == "/items/{item_id}"
assert contexts[0].methods == {"GET"}
def test_iter_route_contexts_returns_nested_effective_paths():
leaf_router = APIRouter()
@leaf_router.get("/me")
def read_me():
return {"me": True}
child_router = APIRouter()
child_router.include_router(leaf_router, prefix="/user")
parent_router = APIRouter()
parent_router.include_router(child_router, prefix="/auth")
app = FastAPI()
app.include_router(parent_router, prefix="/api")
contexts = [
context
for context in iter_route_contexts(app.routes)
if getattr(context, "name", None) == "read_me"
]
assert len(contexts) == 1
assert contexts[0].path == "/api/auth/user/me"
assert contexts[0].path_format == "/api/auth/user/me"
assert contexts[0].endpoint is read_me
def test_iter_route_contexts_returns_each_inclusion_of_same_router():
router = APIRouter()
@router.get("/items")
def read_items():
return []
parent_router = APIRouter()
parent_router.include_router(router, prefix="/v1")
parent_router.include_router(router, prefix="/v2")
paths = [
context.path
for context in iter_route_contexts(parent_router.routes)
if getattr(context, "name", None) == "read_items"
]
assert paths == ["/v1/items", "/v2/items"]
def test_iter_route_contexts_supports_nested_conflict_detection():
existing_router = APIRouter()
nested_router = APIRouter()
@nested_router.get("/me")
def read_me():
return {"me": True}
existing_router.include_router(nested_router, prefix="/auth/user")
new_router = APIRouter()
@new_router.get("/auth/user/me")
def read_me_again():
return {"me": False}
existing_paths = {
context.path for context in iter_route_contexts(existing_router.routes)
}
new_paths = {context.path for context in iter_route_contexts(new_router.routes)}
assert existing_paths & new_paths == {"/auth/user/me"}
def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths():
router = APIRouter()
@router.get("/public", tags=["public"])
def read_public():
return {"public": True}
@router.get("/private", tags=["private"])
def read_private():
return {"private": True}
app = FastAPI()
app.include_router(router, prefix="/api")
public_routes = [
context
for context in iter_route_contexts(app.routes)
if "public" in getattr(context, "tags", [])
]
schema = get_openapi(
title="Public API",
version="1.0.0",
routes=public_routes,
)
assert set(schema["paths"]) == {"/api/public"}
def test_router_include_context_matches_flattened_include_metadata():
callback_router = APIRouter()