🔧 Add ty configs to check docs sources (#15770)

This commit is contained in:
Sebastián Ramírez
2026-06-15 17:53:46 +02:00
committed by GitHub
parent 4473a0cd91
commit b7de2b7feb
44 changed files with 149 additions and 118 deletions

View File

@@ -45,7 +45,7 @@ repos:
- id: local-ty
name: ty check
entry: uv run ty check fastapi docs_src --force-exclude
entry: uv run ty check
require_serial: true
language: unsupported
pass_filenames: false

View File

@@ -237,7 +237,7 @@ def update_content(*, content_path: Path, new_content: Any) -> bool:
def main() -> None:
logging.basicConfig(level=logging.INFO)
settings = Settings()
settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(settings.github_token.get_secret_value())
repo = g.get_repo(settings.github_repository)

View File

@@ -24,7 +24,7 @@ class LinkData(BaseModel):
def main() -> None:
logging.basicConfig(level=logging.INFO)
settings = Settings()
settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(auth=Auth.Token(settings.github_token.get_secret_value()))

View File

@@ -625,14 +625,14 @@ def replace_multiline_code_block(
_line_b_code, line_b_comment = _split_hash_comment(line_b)
res_line = line_b
if line_b_comment:
res_line = res_line.replace(line_b_comment, line_a_comment, 1)
res_line = res_line.replace(line_b_comment, line_a_comment or "", 1)
code_block.append(res_line)
elif block_language in {"console", "json", "slash-style-comments"}:
_line_a_code, line_a_comment = _split_slashes_comment(line_a)
_line_b_code, line_b_comment = _split_slashes_comment(line_b)
res_line = line_b
if line_b_comment:
res_line = res_line.replace(line_b_comment, line_a_comment, 1)
res_line = res_line.replace(line_b_comment, line_a_comment or "", 1)
code_block.append(res_line)
else:
code_block.append(line_b)

View File

@@ -155,7 +155,7 @@ def build_lang(
"""
build_zensical_lang_to_stage(lang)
copy_zensical_stage_to_site(lang)
typer.secho(f"Successfully built docs for: {lang}", color=typer.colors.GREEN)
typer.secho(f"Successfully built docs for: {lang}", fg=typer.colors.GREEN)
def split_markdown_header(markdown: str) -> tuple[str, str]:
@@ -408,7 +408,7 @@ def build_all() -> None:
for lang in langs:
if lang != "en":
copy_zensical_stage_to_site(lang)
typer.secho("Successfully built all docs", color=typer.colors.GREEN)
typer.secho("Successfully built all docs", fg=typer.colors.GREEN)
@app.command()

View File

@@ -22,7 +22,7 @@ class Settings(BaseSettings):
config: dict[str, LabelSettings] | Literal[""] = default_config
settings = Settings()
settings = Settings() # ty: ignore[missing-argument]
if settings.debug:
logging.basicConfig(level=logging.DEBUG)
else:

View File

@@ -4,6 +4,6 @@ set -e
set -x
mypy fastapi
ty check fastapi docs_src --force-exclude
ty check
ruff check fastapi tests docs_src scripts
ruff format fastapi tests --check

View File

@@ -304,7 +304,7 @@ def update_comment(*, settings: Settings, comment_id: str, body: str) -> Comment
def main() -> None:
settings = Settings()
settings = Settings() # ty: ignore[missing-argument]
if settings.debug:
logging.basicConfig(level=logging.DEBUG)
else:
@@ -324,6 +324,7 @@ def main() -> None:
) or settings.number
if number is None:
raise RuntimeError("No PR number available")
number = cast(int, number)
# Avoid race conditions with multiple labels
sleep_time = random.random() * 10 # random number between 0 and 10 seconds

View File

@@ -394,7 +394,7 @@ def update_content(*, content_path: Path, new_content: Any) -> bool:
def main() -> None:
logging.basicConfig(level=logging.INFO)
settings = Settings()
settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}")
rate_limiter.speed_multiplier = settings.speed_multiplier
g = Github(settings.github_token.get_secret_value())

View File

@@ -158,7 +158,7 @@ def update_content(*, content_path: Path, new_content: Any) -> bool:
def main() -> None:
logging.basicConfig(level=logging.INFO)
settings = Settings()
settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(settings.pr_token.get_secret_value())
repo = g.get_repo(settings.github_repository)

View File

@@ -24,7 +24,7 @@ class Repo(BaseModel):
def main() -> None:
logging.basicConfig(level=logging.INFO)
settings = Settings()
settings = Settings() # ty: ignore[missing-argument]
logging.info(f"Using config: {settings.model_dump_json()}")
g = Github(settings.github_token.get_secret_value(), per_page=100)

View File

@@ -1,3 +1,5 @@
from typing import Any, cast
from fastapi import FastAPI, UploadFile
from fastapi._compat import (
Undefined,
@@ -56,9 +58,15 @@ def test_propagates_pydantic2_model_config():
@app.post("/")
def foo(req: Model) -> dict[str, str | None]:
value = req.value
if isinstance(value, Missing):
value = None
embedded_value = req.embedded_model.value
if isinstance(embedded_value, Missing):
embedded_value = None
return {
"value": req.value or None,
"embedded_value": req.embedded_model.value or None,
"value": value,
"embedded_value": embedded_value,
}
client = TestClient(app)
@@ -100,7 +108,7 @@ def test_serialize_sequence_value_with_optional_list():
"""Test that serialize_sequence_value handles optional lists correctly."""
from fastapi._compat import v2
field_info = FieldInfo(annotation=list[str] | None)
field_info = FieldInfo(annotation=cast(Any, list[str] | None))
field = v2.ModelField(name="items", field_info=field_info)
result = v2.serialize_sequence_value(field=field, value=["a", "b", "c"])
assert result == ["a", "b", "c"]
@@ -111,7 +119,7 @@ def test_serialize_sequence_value_with_optional_list_pipe_union():
"""Test that serialize_sequence_value handles optional lists correctly (with new syntax)."""
from fastapi._compat import v2
field_info = FieldInfo(annotation=list[str] | None)
field_info = FieldInfo(annotation=cast(Any, list[str] | None))
field = v2.ModelField(name="items", field_info=field_info)
result = v2.serialize_sequence_value(field=field, value=["a", "b", "c"])
assert result == ["a", "b", "c"]
@@ -125,7 +133,7 @@ def test_serialize_sequence_value_with_none_first_in_union():
from fastapi._compat import v2
# Use Union[None, list[str]] to ensure None comes first in the union args
field_info = FieldInfo(annotation=Union[None, list[str]]) # noqa: UP007
field_info = FieldInfo(annotation=cast(Any, Union[None, list[str]])) # noqa: UP007
field = v2.ModelField(name="items", field_info=field_info)
result = v2.serialize_sequence_value(field=field, value=["x", "y"])
assert result == ["x", "y"]

View File

@@ -3,6 +3,7 @@ from pathlib import Path
from fastapi import APIRouter, FastAPI, File, UploadFile
from fastapi.exceptions import HTTPException
from fastapi.testclient import TestClient
from starlette.types import ASGIApp
app = FastAPI()
@@ -16,7 +17,7 @@ class ContentSizeLimitMiddleware:
max_content_size (optional): the maximum content size allowed in bytes, None for no limit
"""
def __init__(self, app: APIRouter, max_content_size: int | None = None):
def __init__(self, app: ASGIApp, max_content_size: int | None = None):
self.app = app
self.max_content_size = max_content_size
@@ -31,6 +32,7 @@ class ContentSizeLimitMiddleware:
body_len = len(message.get("body", b""))
received += body_len
assert self.max_content_size is not None
if received > self.max_content_size:
raise HTTPException(
422,

View File

@@ -1,9 +1,10 @@
import io
from pathlib import Path
from typing import cast
import pytest
from fastapi import FastAPI, UploadFile
from fastapi.datastructures import Default
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.testclient import TestClient
@@ -13,8 +14,8 @@ def test_upload_file_invalid_pydantic_v2():
def test_default_placeholder_equals():
placeholder_1 = Default("a")
placeholder_2 = Default("a")
placeholder_1 = cast(DefaultPlaceholder, Default("a"))
placeholder_2 = cast(DefaultPlaceholder, Default("a"))
assert placeholder_1 == placeholder_2
assert placeholder_1.value == placeholder_2.value

View File

@@ -11,7 +11,7 @@ class ORJSONResponse(JSONResponse):
media_type = "application/x-orjson"
def render(self, content: Any) -> bytes:
import orjson
import orjson # ty: ignore[unresolved-import]
return orjson.dumps(content)

View File

@@ -3,7 +3,7 @@ import warnings
import pytest
from fastapi import FastAPI
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.responses import ORJSONResponse, UJSONResponse
from fastapi.responses import ORJSONResponse, UJSONResponse # ty: ignore[deprecated]
from fastapi.testclient import TestClient
from pydantic import BaseModel
@@ -21,7 +21,7 @@ class Item(BaseModel):
def _make_orjson_app() -> FastAPI:
with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning)
app = FastAPI(default_response_class=ORJSONResponse)
app = FastAPI(default_response_class=ORJSONResponse) # ty: ignore[deprecated]
@app.get("/items")
def get_items() -> Item:
@@ -44,7 +44,7 @@ def test_orjson_response_returns_correct_data():
@needs_orjson
def test_orjson_response_emits_deprecation_warning():
with pytest.warns(FastAPIDeprecationWarning, match="ORJSONResponse is deprecated"):
ORJSONResponse(content={"hello": "world"})
ORJSONResponse(content={"hello": "world"}) # ty: ignore[deprecated]
# UJSON
@@ -53,7 +53,7 @@ def test_orjson_response_emits_deprecation_warning():
def _make_ujson_app() -> FastAPI:
with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning)
app = FastAPI(default_response_class=UJSONResponse)
app = FastAPI(default_response_class=UJSONResponse) # ty: ignore[deprecated]
@app.get("/items")
def get_items() -> Item:
@@ -76,4 +76,4 @@ def test_ujson_response_returns_correct_data():
@needs_ujson
def test_ujson_response_emits_deprecation_warning():
with pytest.warns(FastAPIDeprecationWarning, match="UJSONResponse is deprecated"):
UJSONResponse(content={"hello": "world"})
UJSONResponse(content={"hello": "world"}) # ty: ignore[deprecated]

View File

@@ -13,7 +13,7 @@ class MyUuid:
def __str__(self):
return self.uuid
@property # type: ignore
@property
def __class__(self):
return uuid.UUID

View File

@@ -87,10 +87,10 @@ def test_encode_dict():
def test_encode_dict_include_exclude_list():
pet = {"name": "Firulais", "owner": {"name": "Foo"}}
assert jsonable_encoder(pet) == {"name": "Firulais", "owner": {"name": "Foo"}}
assert jsonable_encoder(pet, include=["name"]) == {"name": "Firulais"}
assert jsonable_encoder(pet, exclude=["owner"]) == {"name": "Firulais"}
assert jsonable_encoder(pet, include=[]) == {}
assert jsonable_encoder(pet, exclude=[]) == {
assert jsonable_encoder(pet, include=["name"]) == {"name": "Firulais"} # ty: ignore[invalid-argument-type]
assert jsonable_encoder(pet, exclude=["owner"]) == {"name": "Firulais"} # ty: ignore[invalid-argument-type]
assert jsonable_encoder(pet, include=[]) == {} # ty: ignore[invalid-argument-type]
assert jsonable_encoder(pet, exclude=[]) == { # ty: ignore[invalid-argument-type]
"name": "Firulais",
"owner": {"name": "Foo"},
}
@@ -176,7 +176,7 @@ def test_encode_model_with_config():
def test_encode_model_with_alias_raises():
with pytest.raises(ValidationError):
ModelWithAlias(foo="Bar")
ModelWithAlias(foo="Bar") # ty: ignore[missing-argument, unknown-argument]
def test_encode_model_with_alias():

View File

@@ -9,7 +9,7 @@ def test_strings_in_generated_swagger():
swagger_css_url = sig.parameters.get("swagger_css_url").default # type: ignore
swagger_favicon_url = sig.parameters.get("swagger_favicon_url").default # type: ignore
html = get_swagger_ui_html(openapi_url="/docs", title="title")
body_content = html.body.decode()
body_content = bytes(html.body).decode()
assert swagger_js_url in body_content
assert swagger_css_url in body_content
assert swagger_favicon_url in body_content
@@ -26,7 +26,7 @@ def test_strings_in_custom_swagger():
swagger_css_url=swagger_css_url,
swagger_favicon_url=swagger_favicon_url,
)
body_content = html.body.decode()
body_content = bytes(html.body).decode()
assert swagger_js_url in body_content
assert swagger_css_url in body_content
assert swagger_favicon_url in body_content
@@ -37,7 +37,7 @@ def test_strings_in_generated_redoc():
redoc_js_url = sig.parameters.get("redoc_js_url").default # type: ignore
redoc_favicon_url = sig.parameters.get("redoc_favicon_url").default # type: ignore
html = get_redoc_html(openapi_url="/docs", title="title")
body_content = html.body.decode()
body_content = bytes(html.body).decode()
assert redoc_js_url in body_content
assert redoc_favicon_url in body_content
@@ -51,17 +51,17 @@ def test_strings_in_custom_redoc():
redoc_js_url=redoc_js_url,
redoc_favicon_url=redoc_favicon_url,
)
body_content = html.body.decode()
body_content = bytes(html.body).decode()
assert redoc_js_url in body_content
assert redoc_favicon_url in body_content
def test_google_fonts_in_generated_redoc():
body_with_google_fonts = get_redoc_html(
openapi_url="/docs", title="title"
).body.decode()
body_with_google_fonts = bytes(
get_redoc_html(openapi_url="/docs", title="title").body
).decode()
assert "fonts.googleapis.com" in body_with_google_fonts
body_without_google_fonts = get_redoc_html(
openapi_url="/docs", title="title", with_google_fonts=False
).body.decode()
body_without_google_fonts = bytes(
get_redoc_html(openapi_url="/docs", title="title", with_google_fonts=False).body
).decode()
assert "fonts.googleapis.com" not in body_without_google_fonts

View File

@@ -21,4 +21,4 @@ def test_allowed_schema_type(
def test_invalid_type_value() -> None:
"""Test that Schema raises ValueError for invalid type values."""
with pytest.raises(ValueError, match="2 validation errors for Schema"):
Schema(type=True) # type: ignore[arg-type]
Schema(type=True) # type: ignore[arg-type] # ty: ignore[invalid-argument-type]

View File

@@ -6,13 +6,13 @@ pytest.importorskip("orjson")
from fastapi import FastAPI
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.responses import ORJSONResponse
from fastapi.responses import ORJSONResponse # ty: ignore[deprecated]
from fastapi.testclient import TestClient
from sqlalchemy.sql.elements import quoted_name
with warnings.catch_warnings():
warnings.simplefilter("ignore", FastAPIDeprecationWarning)
app = FastAPI(default_response_class=ORJSONResponse)
app = FastAPI(default_response_class=ORJSONResponse) # ty: ignore[deprecated]
@app.get("/orjson_non_str_keys")

View File

@@ -78,22 +78,22 @@ def no_response_model_annotation_return_same_model() -> User:
@app.get("/no_response_model-annotation-return_exact_dict")
def no_response_model_annotation_return_exact_dict() -> User:
return {"name": "John", "surname": "Doe"}
return {"name": "John", "surname": "Doe"} # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_invalid_dict")
def no_response_model_annotation_return_invalid_dict() -> User:
return {"name": "John"}
return {"name": "John"} # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_invalid_model")
def no_response_model_annotation_return_invalid_model() -> User:
return Item(name="Foo", price=42.0)
return Item(name="Foo", price=42.0) # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_dict_with_extra_data")
def no_response_model_annotation_return_dict_with_extra_data() -> User:
return {"name": "John", "surname": "Doe", "password_hash": "secret"}
return {"name": "John", "surname": "Doe", "password_hash": "secret"} # ty: ignore[invalid-return-type]
@app.get("/no_response_model-annotation-return_submodel_with_extra_data")
@@ -108,24 +108,24 @@ def response_model_none_annotation_return_same_model() -> User:
@app.get("/response_model_none-annotation-return_exact_dict", response_model=None)
def response_model_none_annotation_return_exact_dict() -> User:
return {"name": "John", "surname": "Doe"}
return {"name": "John", "surname": "Doe"} # ty: ignore[invalid-return-type]
@app.get("/response_model_none-annotation-return_invalid_dict", response_model=None)
def response_model_none_annotation_return_invalid_dict() -> User:
return {"name": "John"}
return {"name": "John"} # ty: ignore[invalid-return-type]
@app.get("/response_model_none-annotation-return_invalid_model", response_model=None)
def response_model_none_annotation_return_invalid_model() -> User:
return Item(name="Foo", price=42.0)
return Item(name="Foo", price=42.0) # ty: ignore[invalid-return-type]
@app.get(
"/response_model_none-annotation-return_dict_with_extra_data", response_model=None
)
def response_model_none_annotation_return_dict_with_extra_data() -> User:
return {"name": "John", "surname": "Doe", "password_hash": "secret"}
return {"name": "John", "surname": "Doe", "password_hash": "secret"} # ty: ignore[invalid-return-type]
@app.get(
@@ -140,21 +140,21 @@ def response_model_none_annotation_return_submodel_with_extra_data() -> User:
"/response_model_model1-annotation_model2-return_same_model", response_model=User
)
def response_model_model1_annotation_model2_return_same_model() -> Item:
return User(name="John", surname="Doe")
return User(name="John", surname="Doe") # ty: ignore[invalid-return-type]
@app.get(
"/response_model_model1-annotation_model2-return_exact_dict", response_model=User
)
def response_model_model1_annotation_model2_return_exact_dict() -> Item:
return {"name": "John", "surname": "Doe"}
return {"name": "John", "surname": "Doe"} # ty: ignore[invalid-return-type]
@app.get(
"/response_model_model1-annotation_model2-return_invalid_dict", response_model=User
)
def response_model_model1_annotation_model2_return_invalid_dict() -> Item:
return {"name": "John"}
return {"name": "John"} # ty: ignore[invalid-return-type]
@app.get(
@@ -169,7 +169,7 @@ def response_model_model1_annotation_model2_return_invalid_model() -> Item:
response_model=User,
)
def response_model_model1_annotation_model2_return_dict_with_extra_data() -> Item:
return {"name": "John", "surname": "Doe", "password_hash": "secret"}
return {"name": "John", "surname": "Doe", "password_hash": "secret"} # ty: ignore[invalid-return-type]
@app.get(
@@ -177,7 +177,7 @@ def response_model_model1_annotation_model2_return_dict_with_extra_data() -> Ite
response_model=User,
)
def response_model_model1_annotation_model2_return_submodel_with_extra_data() -> Item:
return DBUser(name="John", surname="Doe", password_hash="secret")
return DBUser(name="John", surname="Doe", password_hash="secret") # ty: ignore[invalid-return-type]
@app.get(

View File

@@ -31,31 +31,31 @@ def test_router_events(state: State) -> None:
def main() -> dict[str, str]:
return {"message": "Hello World"}
@app.on_event("startup")
@app.on_event("startup") # ty: ignore[deprecated]
def app_startup() -> None:
state.app_startup = True
@app.on_event("shutdown")
@app.on_event("shutdown") # ty: ignore[deprecated]
def app_shutdown() -> None:
state.app_shutdown = True
router = APIRouter()
@router.on_event("startup")
@router.on_event("startup") # ty: ignore[deprecated]
def router_startup() -> None:
state.router_startup = True
@router.on_event("shutdown")
@router.on_event("shutdown") # ty: ignore[deprecated]
def router_shutdown() -> None:
state.router_shutdown = True
sub_router = APIRouter()
@sub_router.on_event("startup")
@sub_router.on_event("startup") # ty: ignore[deprecated]
def sub_router_startup() -> None:
state.sub_router_startup = True
@sub_router.on_event("shutdown")
@sub_router.on_event("shutdown") # ty: ignore[deprecated]
def sub_router_shutdown() -> None:
state.sub_router_shutdown = True
@@ -253,7 +253,7 @@ def test_router_async_shutdown_handler(state: State) -> None:
def main() -> dict[str, str]:
return {"message": "Hello World"}
@app.on_event("shutdown")
@app.on_event("shutdown") # ty: ignore[deprecated]
async def app_shutdown() -> None:
state.app_shutdown = True
@@ -274,7 +274,7 @@ def test_router_sync_generator_lifespan(state: State) -> None:
yield
state.app_shutdown = True
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
app = FastAPI(lifespan=lifespan) # type: ignore[invalid-argument-type] # ty: ignore[invalid-argument-type]
@app.get("/")
def main() -> dict[str, str]:
@@ -300,7 +300,7 @@ def test_router_async_generator_lifespan(state: State) -> None:
yield
state.app_shutdown = True
app = FastAPI(lifespan=lifespan) # type: ignore[arg-type]
app = FastAPI(lifespan=lifespan) # type: ignore[invalid-argument-type] # ty: ignore[invalid-argument-type]
@app.get("/")
def main() -> dict[str, str]:

View File

@@ -26,7 +26,7 @@ def get_client():
@app.get("/users")
async def get_user() -> User:
return {"username": "alice", "role": "admin"}
return {"username": "alice", "role": "admin"} # ty: ignore[invalid-return-type]
client = TestClient(app)
return client

View File

@@ -18,7 +18,7 @@ def get_valid():
@app.get("/items/coerce", response_model=Item)
def get_coerce():
return Item(aliased_name="coerce", price="1.0")
return Item(aliased_name="coerce", price="1.0") # ty: ignore[invalid-argument-type]
@app.get("/items/validlist", response_model=list[Item])
@@ -52,7 +52,7 @@ def get_valid_exclude_unset():
response_model_exclude_unset=True,
)
def get_coerce_exclude_unset():
return Item(aliased_name="coerce", price="1.0")
return Item(aliased_name="coerce", price="1.0") # ty: ignore[invalid-argument-type]
@app.get(

View File

@@ -29,7 +29,7 @@ class ModelDefaults(BaseModel):
@app.get("/", response_model=Model, response_model_exclude_unset=True)
def get_root() -> ModelSubclass:
return ModelSubclass(sub={}, y=1, z=0)
return ModelSubclass(sub={}, y=1, z=0) # ty: ignore[invalid-argument-type]
@app.get(

View File

@@ -227,7 +227,7 @@ def test_server_sent_event_single_line_fields_reject_newlines(
field_name: str, value: str
):
with pytest.raises(ValueError, match=f"SSE '{field_name}' must be a single line"):
ServerSentEvent(data="test", **{field_name: value})
ServerSentEvent(data="test", **{field_name: value}) # ty: ignore[invalid-argument-type]
def test_server_sent_event_negative_retry_rejected():
@@ -237,7 +237,7 @@ def test_server_sent_event_negative_retry_rejected():
def test_server_sent_event_float_retry_rejected():
with pytest.raises(ValueError):
ServerSentEvent(data="test", retry=1.5) # type: ignore[arg-type]
ServerSentEvent(data="test", retry=1.5) # type: ignore[arg-type] # ty: ignore[invalid-argument-type]
def test_raw_data_sent_without_json_encoding(client: TestClient):

View File

@@ -32,7 +32,7 @@ def test_route_converters_int():
response = client.get("/int/5")
assert response.status_code == 200, response.text
assert response.json() == {"int": 5}
assert app.url_path_for("int_convertor", param=5) == "/int/5" # type: ignore
assert app.url_path_for("int_convertor", param=5) == "/int/5"
def test_route_converters_float():
@@ -40,7 +40,7 @@ def test_route_converters_float():
response = client.get("/float/25.5")
assert response.status_code == 200, response.text
assert response.json() == {"float": 25.5}
assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5" # type: ignore
assert app.url_path_for("float_convertor", param=25.5) == "/float/25.5"
def test_route_converters_path():

View File

@@ -10,6 +10,7 @@ import anyio
import pytest
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.types import Message, Scope
pytestmark = [
pytest.mark.anyio,
@@ -45,16 +46,16 @@ async def _run_asgi_and_cancel(app: FastAPI, path: str, timeout: float) -> bool:
"""
chunks: list[bytes] = []
async def receive(): # type: ignore[no-untyped-def]
async def receive() -> Message:
# Simulate a client that never disconnects, rely on cancellation
await anyio.sleep(float("inf"))
return {"type": "http.disconnect"} # pragma: no cover
async def send(message: dict) -> None: # type: ignore[type-arg]
async def send(message: Message) -> None:
if message["type"] == "http.response.body":
chunks.append(message.get("body", b""))
scope = {
scope: Scope = {
"type": "http",
"asgi": {"version": "3.0", "spec_version": "2.0"},
"http_version": "1.1",
@@ -67,7 +68,7 @@ async def _run_asgi_and_cancel(app: FastAPI, path: str, timeout: float) -> bool:
}
with anyio.move_on_after(timeout) as cancel_scope:
await app(scope, receive, send) # type: ignore[arg-type]
await app(scope, receive, send)
# If we got here within the timeout the generator was cancellable.
# cancel_scope.cancelled_caught is True when move_on_after fired.

View File

@@ -8,7 +8,7 @@ def test_init_oauth_html_chars_are_escaped():
title="Test",
init_oauth={"appName": xss_payload},
)
body = html.body.decode()
body = bytes(html.body).decode()
assert "</script><script>" not in body
assert "\\u003c/script\\u003e\\u003cscript\\u003e" in body
@@ -20,7 +20,7 @@ def test_swagger_ui_parameters_html_chars_are_escaped():
title="Test",
swagger_ui_parameters={"customKey": "<img src=x onerror=alert(1)>"},
)
body = html.body.decode()
body = bytes(html.body).decode()
assert "<img src=x onerror=alert(1)>" not in body
assert "\\u003cimg" in body
@@ -31,7 +31,7 @@ def test_normal_init_oauth_still_works():
title="Test",
init_oauth={"clientId": "my-client", "appName": "My App"},
)
body = html.body.decode()
body = bytes(html.body).decode()
assert '"clientId": "my-client"' in body
assert '"appName": "My App"' in body
assert "ui.initOAuth" in body

View File

@@ -157,7 +157,7 @@ def test_post_broken_body(client: TestClient):
def test_post_form_for_json(client: TestClient):
response = client.post("/items/", data={"name": "Foo", "price": 50.5})
response = client.post("/items/", data={"name": "Foo", "price": "50.5"})
assert response.status_code == 422, response.text
assert response.json() == {
"detail": [

View File

@@ -1,4 +1,5 @@
import importlib
from typing import Any
import pytest
from dirty_equals import IsList
@@ -130,7 +131,7 @@ def test_put_missing_required(client: TestClient):
def test_openapi_schema(client: TestClient, mod_name: str):
tags_schema = {"default": [], "title": "Tags"}
tags_schema: dict[str, Any] = {"default": [], "title": "Tags"}
if mod_name.startswith("tutorial001"):
tags_schema.update(UNTYPED_LIST_SCHEMA)
elif mod_name.startswith("tutorial002"):

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Any, cast
from fastapi.testclient import TestClient
@@ -10,7 +11,7 @@ client = TestClient(app)
def test_get(tmp_path: Path):
file_path: Path = tmp_path / "large-video-file.mp4"
tutorial008_py310.some_file_path = str(file_path)
cast(Any, tutorial008_py310).some_file_path = str(file_path)
test_content = b"Fake video bytes"
file_path.write_bytes(test_content)
response = client.get("/")

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Any, cast
from fastapi.testclient import TestClient
@@ -10,7 +11,7 @@ client = TestClient(app)
def test_get(tmp_path: Path):
file_path: Path = tmp_path / "large-video-file.mp4"
tutorial009_py310.some_file_path = str(file_path)
cast(Any, tutorial009_py310).some_file_path = str(file_path)
test_content = b"Fake video bytes"
file_path.write_bytes(test_content)
response = client.get("/")

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Any, cast
from fastapi.testclient import TestClient
@@ -10,7 +11,7 @@ client = TestClient(app)
def test_get(tmp_path: Path):
file_path: Path = tmp_path / "large-video-file.mp4"
tutorial009b_py310.some_file_path = str(file_path)
cast(Any, tutorial009b_py310).some_file_path = str(file_path)
test_content = b"Fake video bytes"
file_path.write_bytes(test_content)
response = client.get("/")

View File

@@ -1,7 +1,7 @@
import importlib
import runpy
import sys
import unittest
from unittest import mock
import pytest
from fastapi.testclient import TestClient
@@ -20,7 +20,7 @@ def get_client():
def test_uvicorn_run_is_not_called_on_import():
if sys.modules.get(MOD_NAME):
del sys.modules[MOD_NAME] # pragma: no cover
with unittest.mock.patch("uvicorn.run") as uvicorn_run_mock:
with mock.patch("uvicorn.run") as uvicorn_run_mock:
importlib.import_module(MOD_NAME)
uvicorn_run_mock.assert_not_called()
@@ -34,12 +34,10 @@ def test_get_root(client: TestClient):
def test_uvicorn_run_called_when_run_as_main(): # Just for coverage
if sys.modules.get(MOD_NAME):
del sys.modules[MOD_NAME]
with unittest.mock.patch("uvicorn.run") as uvicorn_run_mock:
with mock.patch("uvicorn.run") as uvicorn_run_mock:
runpy.run_module(MOD_NAME, run_name="__main__")
uvicorn_run_mock.assert_called_once_with(
unittest.mock.ANY, host="0.0.0.0", port=8000
)
uvicorn_run_mock.assert_called_once_with(mock.ANY, host="0.0.0.0", port=8000)
def test_openapi_schema(client: TestClient):

View File

@@ -1,3 +1,4 @@
from fastapi.routing import APIRoute
from fastapi.testclient import TestClient
from inline_snapshot import snapshot
@@ -14,7 +15,9 @@ def test_get():
def test_dummy_webhook():
# Just for coverage
app.webhooks.routes[0].endpoint({})
route = app.webhooks.routes[0]
assert isinstance(route, APIRoute)
route.endpoint({})
def test_openapi_schema():

View File

@@ -9,4 +9,4 @@ def test_get_name_with_age_pass_int():
def test_get_name_with_age_pass_str():
assert get_name_with_age("John", "30") == "John is this old: 30"
assert get_name_with_age("John", "30") == "John is this old: 30" # ty: ignore[invalid-argument-type]

View File

@@ -4,9 +4,9 @@ from docs_src.python_types.tutorial005_py310 import get_items
def test_get_items():
res = get_items(
"item_a",
"item_b",
"item_c",
"item_d",
"item_e",
"item_b", # ty: ignore[invalid-argument-type]
"item_c", # ty: ignore[invalid-argument-type]
"item_d", # ty: ignore[invalid-argument-type]
"item_e", # ty: ignore[invalid-argument-type]
)
assert res == ("item_a", "item_b", "item_c", "item_d", "item_e")

View File

@@ -1,6 +1,7 @@
import importlib
from functools import lru_cache
from types import ModuleType
from typing import Any, cast
import pytest
from fastapi.testclient import TestClient
@@ -29,12 +30,13 @@ def cache_verify_password(mod: ModuleType):
f"Module {mod.__name__} does not have attribute 'verify_password'"
)
original_func = mod.verify_password
mod_any = cast(Any, mod)
original_func = mod_any.verify_password
cached_func = lru_cache()(original_func)
mod.verify_password = cached_func
mod_any.verify_password = cached_func
yield
mod.verify_password = original_func
mod_any.verify_password = original_func
def get_access_token(

View File

@@ -1,5 +1,6 @@
import importlib
import warnings
from typing import Any, cast
import pytest
from dirty_equals import IsInt
@@ -35,15 +36,18 @@ def get_client(request: pytest.FixtureRequest):
mod = importlib.import_module(f"docs_src.sql_databases.{request.param}")
clear_sqlmodel()
importlib.reload(mod)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(
mod.sqlite_url, connect_args={"check_same_thread": False}, poolclass=StaticPool
mod_any = cast(Any, mod)
mod_any.sqlite_url = "sqlite://"
mod_any.engine = create_engine(
mod_any.sqlite_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
with TestClient(mod.app) as c:
with TestClient(mod_any.app) as c:
yield c
# Clean up connection explicitly to avoid resource warning
mod.engine.dispose()
mod_any.engine.dispose()
def test_crud_app(client: TestClient):

View File

@@ -1,5 +1,6 @@
import importlib
import warnings
from typing import Any, cast
import pytest
from dirty_equals import IsInt
@@ -35,15 +36,18 @@ def get_client(request: pytest.FixtureRequest):
mod = importlib.import_module(f"docs_src.sql_databases.{request.param}")
clear_sqlmodel()
importlib.reload(mod)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(
mod.sqlite_url, connect_args={"check_same_thread": False}, poolclass=StaticPool
mod_any = cast(Any, mod)
mod_any.sqlite_url = "sqlite://"
mod_any.engine = create_engine(
mod_any.sqlite_url,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
with TestClient(mod.app) as c:
with TestClient(mod_any.app) as c:
yield c
# Clean up connection explicitly to avoid resource warning
mod.engine.dispose()
mod_any.engine.dispose()
def test_crud_app(client: TestClient):

View File

@@ -33,7 +33,10 @@ client = TestClient(app)
def test_dummy_webhook():
# Just for coverage
new_subscription(body={}, token="Bearer 123")
new_subscription(
body=Subscription(username="rick", monthly_fee=9.99, start_date=datetime.now()),
token="Bearer 123",
)
def test_openapi_schema():

View File

@@ -1,5 +1,5 @@
import importlib
import sys
from importlib.util import find_spec
import pytest
@@ -11,12 +11,12 @@ needs_py314 = pytest.mark.skipif(
)
needs_orjson = pytest.mark.skipif(
importlib.util.find_spec("orjson") is None,
find_spec("orjson") is None,
reason="requires orjson",
)
needs_ujson = pytest.mark.skipif(
importlib.util.find_spec("ujson") is None,
find_spec("ujson") is None,
reason="requires ujson",
)