mirror of
https://github.com/fastapi/fastapi.git
synced 2026-06-17 20:09:08 -04:00
✨ Add iter_route_contexts() for advanced use cases that used to use router.routes (e.g. Jupyverse)
This commit is contained in:
@@ -479,26 +479,22 @@ def get_openapi_path(
|
||||
|
||||
|
||||
def _get_api_route_for_openapi(
|
||||
route: BaseRoute, route_context: routing._EffectiveRouteContext | None
|
||||
route_context: routing.RouteContext,
|
||||
) -> routing._APIRouteLike | None:
|
||||
if route_context is not None and isinstance(
|
||||
route_context.original_route, routing.APIRoute
|
||||
):
|
||||
if isinstance(route_context.original_route, routing.APIRoute):
|
||||
return cast(routing._APIRouteLike, route_context)
|
||||
if isinstance(route, routing.APIRoute):
|
||||
return cast(routing._APIRouteLike, route)
|
||||
return None
|
||||
|
||||
|
||||
def get_fields_from_routes(
|
||||
routes: Sequence[BaseRoute],
|
||||
routes: Sequence[BaseRoute | routing.RouteContext],
|
||||
) -> list[ModelField]:
|
||||
body_fields_from_routes: list[ModelField] = []
|
||||
responses_from_routes: list[ModelField] = []
|
||||
request_fields_from_routes: list[ModelField] = []
|
||||
callback_flat_models: list[ModelField] = []
|
||||
for route, route_context in routing._iter_routes_with_context(routes):
|
||||
api_route = _get_api_route_for_openapi(route, route_context)
|
||||
for route_context in routing.iter_route_contexts(routes):
|
||||
api_route = _get_api_route_for_openapi(route_context)
|
||||
if api_route is None:
|
||||
continue
|
||||
if api_route.include_in_schema:
|
||||
@@ -531,8 +527,8 @@ def get_openapi(
|
||||
openapi_version: str = "3.1.0",
|
||||
summary: str | None = None,
|
||||
description: str | None = None,
|
||||
routes: Sequence[BaseRoute],
|
||||
webhooks: Sequence[BaseRoute] | None = None,
|
||||
routes: Sequence[BaseRoute | routing.RouteContext],
|
||||
webhooks: Sequence[BaseRoute | routing.RouteContext] | None = None,
|
||||
tags: list[dict[str, Any]] | None = None,
|
||||
servers: list[dict[str, str | Any]] | None = None,
|
||||
terms_of_service: str | None = None,
|
||||
@@ -567,8 +563,8 @@ def get_openapi(
|
||||
model_name_map=model_name_map,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
for route, route_context in routing._iter_routes_with_context(routes):
|
||||
api_route = _get_api_route_for_openapi(route, route_context)
|
||||
for route_context in routing.iter_route_contexts(routes):
|
||||
api_route = _get_api_route_for_openapi(route_context)
|
||||
if api_route is not None:
|
||||
result = get_openapi_path(
|
||||
route=api_route,
|
||||
@@ -587,8 +583,8 @@ def get_openapi(
|
||||
)
|
||||
if path_definitions:
|
||||
definitions.update(path_definitions)
|
||||
for webhook, webhook_context in routing._iter_routes_with_context(webhooks or []):
|
||||
api_webhook = _get_api_route_for_openapi(webhook, webhook_context)
|
||||
for webhook_context in routing.iter_route_contexts(webhooks or []):
|
||||
api_webhook = _get_api_route_for_openapi(webhook_context)
|
||||
if api_webhook is not None:
|
||||
result = get_openapi_path(
|
||||
route=api_webhook,
|
||||
|
||||
@@ -1454,6 +1454,47 @@ class _EffectiveRouteContext:
|
||||
return URLPath(path=path, protocol="http")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteContext:
|
||||
route: BaseRoute
|
||||
_route_context: _EffectiveRouteContext | None = field(default=None, repr=False)
|
||||
|
||||
@property
|
||||
def original_route(self) -> BaseRoute:
|
||||
if self._route_context is not None:
|
||||
return self._route_context.original_route
|
||||
return self.route
|
||||
|
||||
@property
|
||||
def _effective_route(self) -> BaseRoute | _EffectiveRouteContext:
|
||||
if self._route_context is not None:
|
||||
return self._route_context
|
||||
return self.route
|
||||
|
||||
@property
|
||||
def path(self) -> str | None:
|
||||
return getattr(self._effective_route, "path", None)
|
||||
|
||||
@property
|
||||
def path_format(self) -> str | None:
|
||||
return getattr(self._effective_route, "path_format", None)
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
return getattr(self._effective_route, "name", None)
|
||||
|
||||
@property
|
||||
def methods(self) -> set[str] | None:
|
||||
return getattr(self._effective_route, "methods", None)
|
||||
|
||||
@property
|
||||
def endpoint(self) -> Callable[..., Any] | None:
|
||||
return getattr(self._effective_route, "endpoint", None)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._effective_route, name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _IncludedRouter(BaseRoute):
|
||||
original_router: "APIRouter"
|
||||
@@ -1654,6 +1695,20 @@ def _iter_included_route_candidates(routes: Sequence[BaseRoute]) -> Iterator[Bas
|
||||
yield route
|
||||
|
||||
|
||||
def iter_route_contexts(
|
||||
routes: Sequence[BaseRoute | RouteContext],
|
||||
) -> Iterator[RouteContext]:
|
||||
for route in routes:
|
||||
if isinstance(route, RouteContext):
|
||||
yield route
|
||||
continue
|
||||
for original_route, route_context in _iter_routes_with_context([route]):
|
||||
if route_context is None:
|
||||
yield RouteContext(original_route)
|
||||
else:
|
||||
yield RouteContext(original_route, route_context)
|
||||
|
||||
|
||||
def _iter_routes_with_context(
|
||||
routes: Sequence[BaseRoute],
|
||||
) -> Iterator[tuple[BaseRoute, _EffectiveRouteContext | None]]:
|
||||
|
||||
@@ -3,12 +3,15 @@ from typing import Annotated, cast
|
||||
import pytest
|
||||
from fastapi import APIRouter, Body, Depends, FastAPI, Request
|
||||
from fastapi.exceptions import FastAPIError
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
|
||||
from fastapi.routing import (
|
||||
APIRoute,
|
||||
RouteContext,
|
||||
_IncludedRouter,
|
||||
_iter_included_route_candidates,
|
||||
_restore_fastapi_scope_key,
|
||||
iter_route_contexts,
|
||||
)
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router
|
||||
@@ -30,6 +33,123 @@ def unique_id_b(route: APIRoute) -> str:
|
||||
return f"b_{route.name}"
|
||||
|
||||
|
||||
def test_iter_route_contexts_returns_direct_route_context():
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/items/{item_id}")
|
||||
def read_item(item_id: str):
|
||||
return {"item_id": item_id}
|
||||
|
||||
contexts = list(iter_route_contexts(router.routes))
|
||||
|
||||
assert len(contexts) == 1
|
||||
assert isinstance(contexts[0], RouteContext)
|
||||
assert contexts[0].original_route is router.routes[0]
|
||||
assert contexts[0].path == "/items/{item_id}"
|
||||
assert contexts[0].path_format == "/items/{item_id}"
|
||||
assert contexts[0].methods == {"GET"}
|
||||
|
||||
|
||||
def test_iter_route_contexts_returns_nested_effective_paths():
|
||||
leaf_router = APIRouter()
|
||||
|
||||
@leaf_router.get("/me")
|
||||
def read_me():
|
||||
return {"me": True}
|
||||
|
||||
child_router = APIRouter()
|
||||
child_router.include_router(leaf_router, prefix="/user")
|
||||
|
||||
parent_router = APIRouter()
|
||||
parent_router.include_router(child_router, prefix="/auth")
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(parent_router, prefix="/api")
|
||||
|
||||
contexts = [
|
||||
context
|
||||
for context in iter_route_contexts(app.routes)
|
||||
if getattr(context, "name", None) == "read_me"
|
||||
]
|
||||
|
||||
assert len(contexts) == 1
|
||||
assert contexts[0].path == "/api/auth/user/me"
|
||||
assert contexts[0].path_format == "/api/auth/user/me"
|
||||
assert contexts[0].endpoint is read_me
|
||||
|
||||
|
||||
def test_iter_route_contexts_returns_each_inclusion_of_same_router():
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/items")
|
||||
def read_items():
|
||||
return []
|
||||
|
||||
parent_router = APIRouter()
|
||||
parent_router.include_router(router, prefix="/v1")
|
||||
parent_router.include_router(router, prefix="/v2")
|
||||
|
||||
paths = [
|
||||
context.path
|
||||
for context in iter_route_contexts(parent_router.routes)
|
||||
if getattr(context, "name", None) == "read_items"
|
||||
]
|
||||
|
||||
assert paths == ["/v1/items", "/v2/items"]
|
||||
|
||||
|
||||
def test_iter_route_contexts_supports_nested_conflict_detection():
|
||||
existing_router = APIRouter()
|
||||
nested_router = APIRouter()
|
||||
|
||||
@nested_router.get("/me")
|
||||
def read_me():
|
||||
return {"me": True}
|
||||
|
||||
existing_router.include_router(nested_router, prefix="/auth/user")
|
||||
|
||||
new_router = APIRouter()
|
||||
|
||||
@new_router.get("/auth/user/me")
|
||||
def read_me_again():
|
||||
return {"me": False}
|
||||
|
||||
existing_paths = {
|
||||
context.path for context in iter_route_contexts(existing_router.routes)
|
||||
}
|
||||
new_paths = {context.path for context in iter_route_contexts(new_router.routes)}
|
||||
|
||||
assert existing_paths & new_paths == {"/auth/user/me"}
|
||||
|
||||
|
||||
def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths():
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/public", tags=["public"])
|
||||
def read_public():
|
||||
return {"public": True}
|
||||
|
||||
@router.get("/private", tags=["private"])
|
||||
def read_private():
|
||||
return {"private": True}
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
public_routes = [
|
||||
context
|
||||
for context in iter_route_contexts(app.routes)
|
||||
if "public" in getattr(context, "tags", [])
|
||||
]
|
||||
schema = get_openapi(
|
||||
title="Public API",
|
||||
version="1.0.0",
|
||||
routes=public_routes,
|
||||
)
|
||||
|
||||
assert set(schema["paths"]) == {"/api/public"}
|
||||
|
||||
|
||||
def test_router_include_context_matches_flattened_include_metadata():
|
||||
callback_router = APIRouter()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user