diff --git a/tests/test_request_params/test_body/test_nullable_and_defaults.py b/tests/test_request_params/test_body/test_nullable_and_defaults.py index 0e032e8419..7696f765fd 100644 --- a/tests/test_request_params/test_body/test_nullable_and_defaults.py +++ b/tests/test_request_params/test_body/test_nullable_and_defaults.py @@ -1,25 +1,34 @@ from typing import Annotated, Any, Union +from unittest.mock import Mock, patch import pytest from dirty_equals import IsList, IsOneOf from fastapi import Body, FastAPI from fastapi.testclient import TestClient -from pydantic import BaseModel +from pydantic import BaseModel, BeforeValidator, field_validator from .utils import get_body_model_name app = FastAPI() +def convert(v: Any) -> Any: + return v + + # ===================================================================================== # Nullable required @app.post("/nullable-required") async def read_nullable_required( - int_val: Annotated[Union[int, None], Body()], - str_val: Annotated[Union[str, None], Body()], - list_val: Union[list[int], None], + int_val: Annotated[Union[int, None], Body(), BeforeValidator(lambda v: convert(v))], + str_val: Annotated[Union[str, None], Body(), BeforeValidator(lambda v: convert(v))], + list_val: Annotated[ + Union[list[int], None], + Body(), + BeforeValidator(lambda v: convert(v)), + ], ): return { "int_val": int_val, @@ -34,6 +43,10 @@ class ModelNullableRequired(BaseModel): str_val: Union[str, None] list_val: Union[list[int], None] + @field_validator("*", mode="before") + def validate_all(cls, v): + return convert(v) + @app.post("/model-nullable-required") async def read_model_nullable_required(params: ModelNullableRequired): @@ -47,21 +60,23 @@ async def read_model_nullable_required(params: ModelNullableRequired): @app.post("/nullable-required-str") async def read_nullable_required_no_embed_str( - str_val: Annotated[Union[str, None], Body()], + str_val: Annotated[Union[str, None], Body(), BeforeValidator(lambda v: convert(v))], ): return {"val": str_val} @app.post("/nullable-required-int") async def read_nullable_required_no_embed_int( - int_val: Annotated[Union[int, None], Body()], + int_val: Annotated[Union[int, None], Body(), BeforeValidator(lambda v: convert(v))], ): return {"val": int_val} @app.post("/nullable-required-list") async def read_nullable_required_no_embed_list( - list_val: Annotated[Union[list[int], None], Body()], + list_val: Annotated[ + Union[list[int], None], Body(), BeforeValidator(lambda v: convert(v)) + ], ): return {"val": list_val} @@ -278,14 +293,18 @@ def test_nullable_required_pass_empty_dict(path: str, msg: str, error_type: str) ) def test_nullable_required_pass_null(path: str): client = TestClient(app) - response = client.post( - path, - json={ - "int_val": None, - "str_val": None, - "list_val": None, - }, - ) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + json={ + "int_val": None, + "str_val": None, + "list_val": None, + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" assert response.status_code == 200, response.text assert response.json() == { "int_val": None, @@ -308,10 +327,13 @@ def test_nullable_required_pass_null(path: str): @pytest.mark.xfail(reason="Explicit null-body is treated as missing") def test_nullable_required_no_embed_pass_null(path: str): client = TestClient(app) - response = client.post(path, content="null") + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, content="null") + + assert mock_convert.call_count == 1, "Validator should be called once for the field" assert response.status_code == 200, response.text assert response.json() == {"val": None} - # TODO: add test with BeforeValidator to ensure that it recieves `None` value @pytest.mark.parametrize( @@ -323,9 +345,13 @@ def test_nullable_required_no_embed_pass_null(path: str): ) def test_nullable_required_pass_value(path: str): client = TestClient(app) - response = client.post( - path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} - ) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" assert response.status_code == 200, response.text assert response.json() == { "int_val": 1, @@ -347,10 +373,11 @@ def test_nullable_required_pass_value(path: str): ) def test_nullable_required_no_embed_pass_value(path: str, value: Any): client = TestClient(app) - response = client.post( - path, - json=value, - ) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json=value) + + assert mock_convert.call_count == 1, "Validator should be called once for the field" assert response.status_code == 200, response.text assert response.json() == {"val": value} @@ -361,9 +388,21 @@ def test_nullable_required_no_embed_pass_value(path: str, value: Any): @app.post("/nullable-non-required") async def read_nullable_non_required( - int_val: Annotated[Union[int, None], Body()] = None, - str_val: Annotated[Union[str, None], Body()] = None, - list_val: Union[list[int], None] = None, + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, + list_val: Annotated[ + Union[list[int], None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, ): return { "int_val": int_val, @@ -378,6 +417,10 @@ class ModelNullableNonRequired(BaseModel): str_val: Union[str, None] = None list_val: Union[list[int], None] = None + @field_validator("*", mode="before") + def validate_all(cls, v): + return convert(v) + @app.post("/model-nullable-non-required") async def read_model_nullable_non_required( @@ -393,21 +436,33 @@ async def read_model_nullable_non_required( @app.post("/nullable-non-required-str") async def read_nullable_non_required_no_embed_str( - str_val: Annotated[Union[str, None], Body()] = None, + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, ): return {"val": str_val} @app.post("/nullable-non-required-int") async def read_nullable_non_required_no_embed_int( - int_val: Annotated[Union[int, None], Body()] = None, + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, ): return {"val": int_val} @app.post("/nullable-non-required-list") async def read_nullable_non_required_no_embed_list( - list_val: Annotated[Union[list[int], None], Body()] = None, + list_val: Annotated[ + Union[list[int], None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = None, ): return {"val": list_val} @@ -499,7 +554,13 @@ def test_nullable_non_required_no_embed_schema(path: str, schema: dict): ) def test_nullable_non_required_missing(path: str): client = TestClient(app) - response = client.post(path, json={}) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json={}) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) assert response.status_code == 200 assert response.json() == { "int_val": None, @@ -547,7 +608,13 @@ def test_nullable_non_required_no_body(path: str): ) def test_nullable_non_required_no_embed_missing(path: str): client = TestClient(app) - response = client.post(path) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) assert response.status_code == 200 assert response.json() == {"val": None} @@ -555,20 +622,29 @@ def test_nullable_non_required_no_embed_missing(path: str): @pytest.mark.parametrize( "path", [ - "/nullable-non-required", + pytest.param( + "/nullable-non-required", + marks=pytest.mark.xfail( + reason="Null values are treated as missing for non-model Body parameters" + ), + ), "/model-nullable-non-required", ], ) def test_nullable_non_required_pass_null(path: str): client = TestClient(app) - response = client.post( - path, - json={ - "int_val": None, - "str_val": None, - "list_val": None, - }, - ) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + json={ + "int_val": None, + "str_val": None, + "list_val": None, + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" assert response.status_code == 200, response.text assert response.json() == { "int_val": None, @@ -588,12 +664,16 @@ def test_nullable_non_required_pass_null(path: str): "/nullable-non-required-list", ], ) +@pytest.mark.xfail(reason="Explicit null-body is treated as missing") def test_nullable_non_required_no_embed_pass_null(path: str): client = TestClient(app) - response = client.post(path, content="null") + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, content="null") + + assert mock_convert.call_count == 1, "Validator should be called once for the field" assert response.status_code == 200, response.text assert response.json() == {"val": None} - # TODO: add test with BeforeValidator to ensure that it recieves `None` value @pytest.mark.parametrize( @@ -605,9 +685,13 @@ def test_nullable_non_required_no_embed_pass_null(path: str): ) def test_nullable_non_required_pass_value(path: str): client = TestClient(app) - response = client.post( - path, json={"int_val": 1, "str_val": "test", "list_val": [1, 2]} - ) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, json={"int_val": 1, "str_val": "test", "list_val": [1, 2]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" assert response.status_code == 200, response.text assert response.json() == { "int_val": 1, @@ -629,7 +713,11 @@ def test_nullable_non_required_pass_value(path: str): ) def test_nullable_non_required_no_embed_pass_value(path: str, value: Any): client = TestClient(app) - response = client.post(path, json=value) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json=value) + + assert mock_convert.call_count == 1, "Validator should be called once for the field" assert response.status_code == 200, response.text assert response.json() == {"val": value} @@ -641,9 +729,21 @@ def test_nullable_non_required_no_embed_pass_value(path: str, value: Any): @app.post("/nullable-with-non-null-default") async def read_nullable_with_non_null_default( *, - int_val: Annotated[Union[int, None], Body()] = -1, - str_val: Annotated[Union[str, None], Body()] = "default", - list_val: Annotated[Union[list[int], None], Body(default_factory=lambda: [0])], + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = -1, + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = "default", + list_val: Annotated[ + Union[list[int], None], + Body(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], ): return { "int_val": int_val, @@ -658,6 +758,10 @@ class ModelNullableWithNonNullDefault(BaseModel): str_val: Union[str, None] = "default" list_val: Union[list[int], None] = [0] + @field_validator("*", mode="before") + def validate_all(cls, v): + return convert(v) + @app.post("/model-nullable-with-non-null-default") async def read_model_nullable_with_non_null_default( @@ -673,21 +777,33 @@ async def read_model_nullable_with_non_null_default( @app.post("/nullable-with-non-null-default-str") async def read_nullable_with_non_null_default_no_embed_str( - str_val: Annotated[Union[str, None], Body()] = "default", + str_val: Annotated[ + Union[str, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = "default", ): return {"val": str_val} @app.post("/nullable-with-non-null-default-int") async def read_nullable_with_non_null_default_no_embed_int( - int_val: Annotated[Union[int, None], Body()] = -1, + int_val: Annotated[ + Union[int, None], + Body(), + BeforeValidator(lambda v: convert(v)), + ] = -1, ): return {"val": int_val} @app.post("/nullable-with-non-null-default-list") async def read_nullable_with_non_null_default_no_embed_list( - list_val: Annotated[Union[list[int], None], Body(default_factory=lambda: [0])], + list_val: Annotated[ + Union[list[int], None], + Body(default_factory=lambda: [0]), + BeforeValidator(lambda v: convert(v)), + ], ): return {"val": list_val} @@ -787,7 +903,13 @@ def test_nullable_with_non_null_default_no_embed_schema(path: str, schema: dict) ) def test_nullable_with_non_null_default_missing(path: str): client = TestClient(app) - response = client.post(path, json={}) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json={}) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) assert response.status_code == 200, response.text assert response.json() == { "int_val": -1, @@ -835,7 +957,13 @@ def test_nullable_with_non_null_default_no_body(path: str): ) def test_nullable_with_non_null_default_no_embed_missing(path: str, expected: Any): client = TestClient(app) - response = client.post(path) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path) + + assert mock_convert.call_count == 0, ( + "Validator should not be called if the value is missing" + ) assert response.status_code == 200, response.text assert response.json() == {"val": expected} @@ -854,14 +982,18 @@ def test_nullable_with_non_null_default_no_embed_missing(path: str, expected: An ) def test_nullable_with_non_null_default_pass_null(path: str): client = TestClient(app) - response = client.post( - path, - json={ - "int_val": None, - "str_val": None, - "list_val": None, - }, - ) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, + json={ + "int_val": None, + "str_val": None, + "list_val": None, + }, + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" assert response.status_code == 200, response.text assert response.json() == { "int_val": None, @@ -884,7 +1016,11 @@ def test_nullable_with_non_null_default_pass_null(path: str): @pytest.mark.xfail(reason="Explicit null-body is treated as missing") def test_nullable_with_non_null_default_no_embed_pass_null(path: str): client = TestClient(app) - response = client.post(path, content="null") + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, content="null") + + assert mock_convert.call_count == 1, "Validator should be called once for the field" assert response.status_code == 200, response.text assert response.json() == {"val": None} @@ -898,9 +1034,13 @@ def test_nullable_with_non_null_default_no_embed_pass_null(path: str): ) def test_nullable_with_non_null_default_pass_value(path: str): client = TestClient(app) - response = client.post( - path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} - ) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post( + path, json={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]} + ) + + assert mock_convert.call_count == 3, "Validator should be called for each field" assert response.status_code == 200, response.text assert response.json() == { "int_val": 1, @@ -922,6 +1062,10 @@ def test_nullable_with_non_null_default_pass_value(path: str): ) def test_nullable_with_non_null_default_no_embed_pass_value(path: str, value: Any): client = TestClient(app) - response = client.post(path, json=value) + + with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert: + response = client.post(path, json=value) + + assert mock_convert.call_count == 1, "Validator should be called once for the field" assert response.status_code == 200, response.text assert response.json() == {"val": value}