diff --git a/tests/test_request_params/test_query/test_nullable_and_defaults.py b/tests/test_request_params/test_query/test_nullable_and_defaults.py new file mode 100644 index 0000000000..bca69444d0 --- /dev/null +++ b/tests/test_request_params/test_query/test_nullable_and_defaults.py @@ -0,0 +1,480 @@ +from typing import Annotated, Any, Union +from unittest.mock import Mock, patch + +import pytest +from dirty_equals import IsList, IsOneOf +from fastapi import FastAPI, Query +from fastapi.testclient import TestClient +from pydantic import BaseModel, BeforeValidator, field_validator + +app = FastAPI() + + +def convert(v: Any) -> Any: + return v + + +# ===================================================================================== +# Nullable required + + +@app.get("/nullable-required") +async def read_nullable_required( + int_val: Annotated[ + Union[int, None], + BeforeValidator(lambda v: convert(v)), + ], + str_val: Annotated[ + Union[str, None], + BeforeValidator(lambda v: convert(v)), + ], + list_val: Annotated[ + Union[list[int], None], + Query(), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableRequired(BaseModel): + int_val: Union[int, None] + str_val: Union[str, None] + list_val: Union[list[int], None] + + @field_validator("*", mode="before") + @classmethod + def convert_all(cls, v: Any) -> Any: + return convert(v) + + +@app.get("/model-nullable-required") +async def read_model_nullable_required( + params: Annotated[ModelNullableRequired, Query()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": True, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + }, + "name": "int_val", + "in": "query", + }, + { + "required": True, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + }, + "name": "str_val", + "in": "query", + }, + { + "in": "query", + "name": "list_val", + "required": True, + "schema": { + "anyOf": [ + {"items": {"type": "integer"}, "type": "array"}, + {"type": "null"}, + ], + "title": "List Val", + }, + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "missing", + "loc": ["query", "int_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["query", "str_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + { + "type": "missing", + "loc": ["query", "list_val"], + "msg": "Field required", + "input": IsOneOf(None, {}), + }, + ] + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-required", + "/model-nullable-required", + ], +) +def test_nullable_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, params={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with default=None + + +@app.get("/nullable-non-required") +async def read_nullable_non_required( + int_val: Annotated[ + Union[int, None], + BeforeValidator(lambda v: convert(v)), + ] = None, + str_val: Annotated[ + Union[str, None], + BeforeValidator(lambda v: convert(v)), + ] = None, + list_val: Annotated[ + Union[list[int], None], + Query(), + BeforeValidator(lambda v: convert(v)), + ] = None, +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableNonRequired(BaseModel): + int_val: Union[int, None] = None + str_val: Union[str, None] = None + list_val: Union[list[int], None] = None + + @field_validator("*", mode="before") + @classmethod + def convert_all(cls, v: Any) -> Any: + return convert(v) + + +@app.get("/model-nullable-non-required") +async def read_model_nullable_non_required( + params: Annotated[ModelNullableNonRequired, Query()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "int_val", + "in": "query", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + # "default": None, # `None` values are omitted in OpenAPI schema + }, + "name": "str_val", + "in": "query", + }, + { + "in": "query", + "name": "list_val", + "required": False, + "schema": { + "anyOf": [ + {"items": {"type": "integer"}, "type": "array"}, + {"type": "null"}, + ], + "title": "List Val", + # "default": None, # `None` values are omitted in OpenAPI schema + }, + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "int_val": None, + "str_val": None, + "list_val": None, + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-non-required", + "/model-nullable-non-required", + ], +) +def test_nullable_non_required_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, params={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + } + + +# ===================================================================================== +# Nullable with not-None default + + +@app.get("/nullable-with-non-null-default") +async def read_nullable_with_non_null_default( + *, + int_val: Annotated[ + Union[int, None], + BeforeValidator(lambda v: convert(v)), + ] = -1, + str_val: Annotated[ + Union[str, None], + BeforeValidator(lambda v: convert(v)), + ] = "default", + list_val: Annotated[ + Union[list[int], None], + Query(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], +): + return { + "int_val": int_val, + "str_val": str_val, + "list_val": list_val, + "fields_set": None, + } + + +class ModelNullableWithNonNullDefault(BaseModel): + int_val: Union[int, None] = -1 + str_val: Union[str, None] = "default" + list_val: Union[list[int], None] = [0] + + @field_validator("*", mode="before") + @classmethod + def convert_all(cls, v: Any) -> Any: + return convert(v) + + +@app.get("/model-nullable-with-non-null-default") +async def read_model_nullable_with_non_null_default( + params: Annotated[ModelNullableWithNonNullDefault, Query()], +): + return { + "int_val": params.int_val, + "str_val": params.str_val, + "list_val": params.list_val, + "fields_set": params.model_fields_set, + } + + +@pytest.mark.parametrize( + "path", + [ + pytest.param( + "/nullable-with-non-null-default", + marks=pytest.mark.xfail( + reason="`default_factory` is not reflected in OpenAPI schema" + ), + ), + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_schema(path: str): + assert app.openapi()["paths"][path]["get"]["parameters"] == [ + { + "required": False, + "schema": { + "title": "Int Val", + "anyOf": [{"type": "integer"}, {"type": "null"}], + "default": -1, + }, + "name": "int_val", + "in": "query", + }, + { + "required": False, + "schema": { + "title": "Str Val", + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": "default", + }, + "name": "str_val", + "in": "query", + }, + { + "in": "query", + "name": "list_val", + "required": False, + "schema": { + "anyOf": [ + {"items": {"type": "integer"}, "type": "array"}, + {"type": "null"}, + ], + "title": "List Val", + "default": [0], # `default_factory` is not reflected in OpenAPI schema + }, + }, + ] + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +@pytest.mark.xfail( + reason="Missing parameters are pre-populated with default values before validation" +) +def test_nullable_with_non_null_default_missing(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) + assert response.status_code == 200 + assert response.json() == { + "int_val": -1, + "str_val": "default", + "list_val": [0], + "fields_set": IsOneOf(None, []), + } + + +@pytest.mark.parametrize( + "path", + [ + "/nullable-with-non-null-default", + "/model-nullable-with-non-null-default", + ], +) +def test_nullable_with_non_null_default_pass_value(path: str): + client = TestClient(app) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.get( + path, params={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" + assert response.status_code == 200, response.text + assert response.json() == { + "int_val": 1, + "str_val": "test", + "list_val": [1, 2], + "fields_set": IsOneOf( + None, IsList("int_val", "str_val", "list_val", check_order=False) + ), + }