diff --git a/tests/test_router_include_context.py b/tests/test_router_include_context.py index 80a8596d48..b1aac39014 100644 --- a/tests/test_router_include_context.py +++ b/tests/test_router_include_context.py @@ -1,7 +1,7 @@ from typing import Annotated, cast import pytest -from fastapi import APIRouter, Body, Depends, FastAPI, Request +from fastapi import APIRouter, Body, Depends, FastAPI, Request, Security from fastapi.exceptions import FastAPIError from fastapi.openapi.utils import get_openapi from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse @@ -13,7 +13,9 @@ from fastapi.routing import ( _restore_fastapi_scope_key, iter_route_contexts, ) +from fastapi.security import HTTPBearer from fastapi.testclient import TestClient +from pydantic import BaseModel from starlette.routing import BaseRoute, Host, Match, Mount, NoMatchFound, Route, Router @@ -77,9 +79,10 @@ def test_iter_route_contexts_supports_nested_conflict_detection(): def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths(): router = APIRouter() + bearer_scheme = HTTPBearer() @router.get("/public", tags=["public"]) - def read_public(): + def read_public(token: Annotated[str, Security(bearer_scheme)]): return {"public": True} @router.get("/private", tags=["private"]) @@ -101,6 +104,33 @@ def test_get_openapi_accepts_filtered_route_contexts_with_effective_paths(): ) assert set(schema["paths"]) == {"/api/public"} + assert "HTTPBearer" in schema["components"]["securitySchemes"] + + +def test_get_openapi_accepts_webhook_route_contexts(): + app = FastAPI() + bearer_scheme = HTTPBearer() + + class Subscription(BaseModel): + username: str + + @app.webhooks.post("new-subscription") + def new_subscription( + body: Subscription, token: Annotated[str, Security(bearer_scheme)] + ): + return None + + webhook_contexts = list(iter_route_contexts(app.webhooks.routes)) + schema = get_openapi( + title="Webhook API", + version="1.0.0", + routes=[], + webhooks=webhook_contexts, + ) + + assert set(schema["webhooks"]) == {"new-subscription"} + assert "HTTPBearer" in schema["components"]["securitySchemes"] + assert "Subscription" in schema["components"]["schemas"] def test_router_include_context_matches_flattened_include_metadata():