mirror of
https://github.com/fastapi/fastapi.git
synced 2026-02-26 03:36:14 -05:00
Compare commits
21 Commits
0.133.1
...
add-tests-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
90ebf74f68 | ||
|
|
7e96b8b8fa | ||
|
|
1a251c63c2 | ||
|
|
d10fa5df11 | ||
|
|
c49c90efc0 | ||
|
|
f49f65aa16 | ||
|
|
c733bab825 | ||
|
|
bfc09d9440 | ||
|
|
2a2aafa01e | ||
|
|
cf0d31bd69 | ||
|
|
7fed2671c4 | ||
|
|
d90bcc8569 | ||
|
|
27cc340880 | ||
|
|
9e85c19d3a | ||
|
|
3441e14197 | ||
|
|
e1adc4a739 | ||
|
|
e6475e960a | ||
|
|
2d43382626 | ||
|
|
0b5fea716b | ||
|
|
22d795d890 | ||
|
|
dada1d581f |
@@ -7,13 +7,6 @@ hide:
|
||||
|
||||
## Latest Changes
|
||||
|
||||
## 0.133.1
|
||||
|
||||
### Features
|
||||
|
||||
* 🔧 Add FastAPI Agents Skill. PR [#14982](https://github.com/fastapi/fastapi/pull/14982) by [@tiangolo](https://github.com/tiangolo).
|
||||
* Read more about it in [Library Agent Skills](https://tiangolo.com/ideas/library-agent-skills/).
|
||||
|
||||
### Internal
|
||||
|
||||
* ✅ Fix all tests are skipped on Windows. PR [#14994](https://github.com/fastapi/fastapi/pull/14994) by [@YuriiMotov](https://github.com/YuriiMotov).
|
||||
|
||||
@@ -1,614 +0,0 @@
|
||||
---
|
||||
name: fastapi
|
||||
description: FastAPI best practices and conventions. Use when working with FastAPI APIs and Pydantic models for them. Keeps FastAPI code clean and up to date with the latest features and patterns, updated with new versions. Write new code or refactor and update old code.
|
||||
---
|
||||
|
||||
# FastAPI
|
||||
|
||||
Official FastAPI skill to write code with best practices, keeping up to date with new versions and features.
|
||||
|
||||
## Use the `fastapi` CLI
|
||||
|
||||
Run the development server on localhost with reload:
|
||||
|
||||
```bash
|
||||
fastapi dev
|
||||
```
|
||||
|
||||
|
||||
Run the production server:
|
||||
|
||||
```bash
|
||||
fastapi run
|
||||
```
|
||||
|
||||
### Add an entrypoint in `pyproject.toml`
|
||||
|
||||
FastAPI CLI will read the entrypoint in `pyproject.toml` to know where the FastAPI app is declared.
|
||||
|
||||
```toml
|
||||
[tool.fastapi]
|
||||
entrypoint = "my_app.main:app"
|
||||
```
|
||||
|
||||
### Use `fastapi` with a path
|
||||
|
||||
When adding the entrypoint to `pyproject.toml` is not possible, or the user explicitly asks not to, or it's running an independent small app, you can pass the app file path to the `fastapi` command:
|
||||
|
||||
```bash
|
||||
fastapi dev my_app/main.py
|
||||
```
|
||||
|
||||
Prefer to set the entrypoint in `pyproject.toml` when possible.
|
||||
|
||||
## Use `Annotated`
|
||||
|
||||
Always prefer the `Annotated` style for parameter and dependency declarations.
|
||||
|
||||
It keeps the function signatures working in other contexts, respects the types, allows reusability.
|
||||
|
||||
### In Parameter Declarations
|
||||
|
||||
Use `Annotated` for parameter declarations, including `Path`, `Query`, `Header`, etc.:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import FastAPI, Path, Query
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/items/{item_id}")
|
||||
async def read_item(
|
||||
item_id: Annotated[int, Path(ge=1, description="The item ID")],
|
||||
q: Annotated[str | None, Query(max_length=50)] = None,
|
||||
):
|
||||
return {"message": "Hello World"}
|
||||
```
|
||||
|
||||
instead of:
|
||||
|
||||
```python
|
||||
# DO NOT DO THIS
|
||||
@app.get("/items/{item_id}")
|
||||
async def read_item(
|
||||
item_id: int = Path(ge=1, description="The item ID"),
|
||||
q: str | None = Query(default=None, max_length=50),
|
||||
):
|
||||
return {"message": "Hello World"}
|
||||
```
|
||||
|
||||
### For Dependencies
|
||||
|
||||
Use `Annotated` for dependencies with `Depends()`.
|
||||
|
||||
Unless asked not to, create a new type alias for the dependency to allow re-using it.
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_current_user():
|
||||
return {"username": "johndoe"}
|
||||
|
||||
|
||||
CurrentUserDep = Annotated[dict, Depends(get_current_user)]
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
async def read_item(current_user: CurrentUserDep):
|
||||
return {"message": "Hello World"}
|
||||
```
|
||||
|
||||
instead of:
|
||||
|
||||
```python
|
||||
# DO NOT DO THIS
|
||||
@app.get("/items/")
|
||||
async def read_item(current_user: dict = Depends(get_current_user)):
|
||||
return {"message": "Hello World"}
|
||||
```
|
||||
|
||||
## Do not use Ellipsis for *path operations* or Pydantic models
|
||||
|
||||
Do not use `...` as a default value for required parameters, it's not needed and not recommended.
|
||||
|
||||
Do this, without Ellipsis (`...`):
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import FastAPI, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
price: float = Field(gt=0)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post("/items/")
|
||||
async def create_item(item: Item, project_id: Annotated[int, Query()]): ...
|
||||
```
|
||||
|
||||
instead of this:
|
||||
|
||||
```python
|
||||
# DO NOT DO THIS
|
||||
class Item(BaseModel):
|
||||
name: str = ...
|
||||
description: str | None = None
|
||||
price: float = Field(..., gt=0)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post("/items/")
|
||||
async def create_item(item: Item, project_id: Annotated[int, Query(...)]): ...
|
||||
```
|
||||
|
||||
## Return Type or Response Model
|
||||
|
||||
When possible, include a return type. It will be used to validate, filter, document, and serialize the response.
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
@app.get("/items/me")
|
||||
async def get_item() -> Item:
|
||||
return Item(name="Plumbus", description="All-purpose home device")
|
||||
```
|
||||
|
||||
**Important**: Return types or response models are what filter data ensuring no sensitive information is exposed. And they are used to serialize data with Pydantic (in Rust), this is the main idea that can increase response performance.
|
||||
|
||||
The return type doesn't have to be a Pydantic model, it could be a different type, like a list of integers, or a dict, etc.
|
||||
|
||||
### When to use `response_model` instead
|
||||
|
||||
If the return type is not the same as the type that you want to use to validate, filter, or serialize, use the `response_model` parameter on the decorator instead.
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
@app.get("/items/me", response_model=Item)
|
||||
async def get_item() -> Any:
|
||||
return {"name": "Foo", "description": "A very nice Item"}
|
||||
```
|
||||
|
||||
This can be particularly useful when filtering data to expose only the public fields and avoid exposing sensitive information.
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class InternalItem(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
secret_key: str
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
|
||||
@app.get("/items/me", response_model=Item)
|
||||
async def get_item() -> Any:
|
||||
item = InternalItem(
|
||||
name="Foo", description="A very nice Item", secret_key="supersecret"
|
||||
)
|
||||
return item
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
Do not use `ORJSONResponse` or `UJSONResponse`, they are deprecated.
|
||||
|
||||
Instead, declare a return type or response model. Pydantic will handle the data serialization on the Rust side.
|
||||
|
||||
## Including Routers
|
||||
|
||||
When declaring routers, prefer to add router level parameters like prefix, tags, etc. to the router itself, instead of in `include_router()`.
|
||||
|
||||
Do this:
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
router = APIRouter(prefix="/items", tags=["items"])
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_items():
|
||||
return []
|
||||
|
||||
|
||||
# In main.py
|
||||
app.include_router(router)
|
||||
```
|
||||
|
||||
instead of this:
|
||||
|
||||
```python
|
||||
# DO NOT DO THIS
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_items():
|
||||
return []
|
||||
|
||||
|
||||
# In main.py
|
||||
app.include_router(router, prefix="/items", tags=["items"])
|
||||
```
|
||||
|
||||
There could be exceptions, but try to follow this convention.
|
||||
|
||||
Apply shared dependencies at the router level via `dependencies=[Depends(...)]`.
|
||||
|
||||
## Dependency Injection
|
||||
|
||||
Use dependencies when:
|
||||
|
||||
* They can't be declared in Pydantic validation and require additional logic
|
||||
* The logic depends on external resources or could block in any other way
|
||||
* Other dependencies need their results (it's a sub-dependency)
|
||||
* The logic can be shared by multiple endpoints to do things like error early, authentication, etc.
|
||||
* They need to handle cleanup (e.g., DB sessions, file handles), using dependencies with `yield`
|
||||
* Their logic needs input data from the request, like headers, query parameters, etc.
|
||||
|
||||
### Dependencies with `yield` and `scope`
|
||||
|
||||
When using dependencies with `yield`, they can have a `scope` that defines when the exit code is run.
|
||||
|
||||
Use the default scope `"request"` to run the exit code after the response is sent back.
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = DBSession()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
DBDep = Annotated[DBSession, Depends(get_db)]
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
async def read_items(db: DBDep):
|
||||
return db.query(Item).all()
|
||||
```
|
||||
|
||||
Use the scope `"function"` when they should run the exit code after the response data is generated but before the response is sent back to the client.
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def get_username():
|
||||
try:
|
||||
yield "Rick"
|
||||
finally:
|
||||
print("Cleanup up before response is sent")
|
||||
|
||||
UserNameDep = Annotated[str, Depends(get_username, scope="function")]
|
||||
|
||||
@app.get("/users/me")
|
||||
def get_user_me(username: UserNameDep):
|
||||
return username
|
||||
```
|
||||
|
||||
### Class Dependencies
|
||||
|
||||
Avoid creating class dependencies when possible.
|
||||
|
||||
If a class is needed, instead create a regular function dependency that returns a class instance.
|
||||
|
||||
Do this:
|
||||
|
||||
```python
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabasePaginator:
|
||||
offset: int = 0
|
||||
limit: int = 100
|
||||
q: str | None = None
|
||||
|
||||
def get_page(self) -> dict:
|
||||
# Simulate a page of data
|
||||
return {
|
||||
"offset": self.offset,
|
||||
"limit": self.limit,
|
||||
"q": self.q,
|
||||
"items": [],
|
||||
}
|
||||
|
||||
|
||||
def get_db_paginator(
|
||||
offset: int = 0, limit: int = 100, q: str | None = None
|
||||
) -> DatabasePaginator:
|
||||
return DatabasePaginator(offset=offset, limit=limit, q=q)
|
||||
|
||||
|
||||
PaginatorDep = Annotated[DatabasePaginator, Depends(get_db_paginator)]
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
async def read_items(paginator: PaginatorDep):
|
||||
return paginator.get_page()
|
||||
```
|
||||
|
||||
instead of this:
|
||||
|
||||
```python
|
||||
# DO NOT DO THIS
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class DatabasePaginator:
|
||||
def __init__(self, offset: int = 0, limit: int = 100, q: str | None = None):
|
||||
self.offset = offset
|
||||
self.limit = limit
|
||||
self.q = q
|
||||
|
||||
def get_page(self) -> dict:
|
||||
# Simulate a page of data
|
||||
return {
|
||||
"offset": self.offset,
|
||||
"limit": self.limit,
|
||||
"q": self.q,
|
||||
"items": [],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
async def read_items(paginator: Annotated[DatabasePaginator, Depends()]):
|
||||
return paginator.get_page()
|
||||
```
|
||||
|
||||
## Async vs Sync *path operations*
|
||||
|
||||
Use `async` *path operations* only when fully certain that the logic called inside is compatible with async and await (it's called with `await`) or that doesn't block.
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Use async def when calling async code
|
||||
@app.get("/async-items/")
|
||||
async def read_async_items():
|
||||
data = await some_async_library.fetch_items()
|
||||
return data
|
||||
|
||||
|
||||
# Use plain def when calling blocking/sync code or when in doubt
|
||||
@app.get("/items/")
|
||||
def read_items():
|
||||
data = some_blocking_library.fetch_items()
|
||||
return data
|
||||
```
|
||||
|
||||
In case of doubt, or by default, use regular `def` functions, those will be run in a threadpool so they don't block the event loop.
|
||||
|
||||
The same rules apply to dependencies.
|
||||
|
||||
Make sure blocking code is not run inside of `async` functions. The logic will work, but will damage the performance heavily.
|
||||
|
||||
### Asyncer
|
||||
|
||||
When needing to run blocking code inside of async functions, or async code inside of blocking functions, suggest using Asyncer.
|
||||
|
||||
Install:
|
||||
|
||||
```bash
|
||||
uv add asyncer
|
||||
```
|
||||
|
||||
Run blocking sync code inside of async with `asyncify()`:
|
||||
|
||||
```python
|
||||
from asyncer import asyncify
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def do_blocking_work(name: str) -> str:
|
||||
# Some blocking I/O operation
|
||||
return f"Hello {name}"
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
async def read_items():
|
||||
result = await asyncify(do_blocking_work)(name="World")
|
||||
return {"message": result}
|
||||
```
|
||||
|
||||
And run async code inside of blocking sync code with `syncify()`:
|
||||
|
||||
```python
|
||||
from asyncer import syncify
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
async def do_async_work(name: str) -> str:
|
||||
return f"Hello {name}"
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
def read_items():
|
||||
result = syncify(do_async_work)(name="World")
|
||||
return {"message": result}
|
||||
```
|
||||
|
||||
## Use uv, ruff, ty
|
||||
|
||||
If uv is available, use it to manage dependencies.
|
||||
|
||||
If Ruff is available, use it to lint and format the code. Consider enabling the FastAPI rules.
|
||||
|
||||
If ty is available, use it to check types.
|
||||
|
||||
## SQLModel for SQL databases
|
||||
|
||||
When working with SQL databases, prefer using SQLModel as it is integrated with Pydantic and will allow declaring data validation with the same models.
|
||||
|
||||
## Do not use Pydantic RootModels
|
||||
|
||||
Do not use Pydantic `RootModel`, instead use regular type annotations with `Annotated` and Pydantic validation utilities.
|
||||
|
||||
For example, for a list with validations you could do:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Body, FastAPI
|
||||
from pydantic import Field
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post("/items/")
|
||||
async def create_items(items: Annotated[list[int], Field(min_length=1), Body()]):
|
||||
return items
|
||||
```
|
||||
|
||||
instead of:
|
||||
|
||||
```python
|
||||
# DO NOT DO THIS
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import Field, RootModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class ItemList(RootModel[Annotated[list[int], Field(min_length=1)]]):
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/items/")
|
||||
async def create_items(items: ItemList):
|
||||
return items
|
||||
|
||||
```
|
||||
|
||||
FastAPI supports these type annotations and will create a Pydantic `TypeAdapter` for them, so that types can work as normally and there's no need for the custom logic and types in RootModels.
|
||||
|
||||
## Use one HTTP operation per function
|
||||
|
||||
Don't mix HTTP operations in a single function, having one function per HTTP operation helps separate concerns and organize the code.
|
||||
|
||||
Do this:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@app.get("/items/")
|
||||
async def list_items():
|
||||
return []
|
||||
|
||||
|
||||
@app.post("/items/")
|
||||
async def create_item(item: Item):
|
||||
return item
|
||||
```
|
||||
|
||||
instead of this:
|
||||
|
||||
```python
|
||||
# DO NOT DO THIS
|
||||
from fastapi import FastAPI, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@app.api_route("/items/", methods=["GET", "POST"])
|
||||
async def handle_items(request: Request):
|
||||
if request.method == "GET":
|
||||
return []
|
||||
```
|
||||
@@ -1,6 +1,6 @@
|
||||
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
|
||||
|
||||
__version__ = "0.133.1"
|
||||
__version__ = "0.133.0"
|
||||
|
||||
from starlette import status as status
|
||||
|
||||
|
||||
1111
tests/test_request_params/test_body/test_nullable_and_defaults.py
Normal file
1111
tests/test_request_params/test_body/test_nullable_and_defaults.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,431 @@
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from dirty_equals import IsList, IsOneOf
|
||||
from fastapi import Cookie, FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import snapshot
|
||||
from pydantic import BaseModel, BeforeValidator, field_validator
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def convert(v: Any) -> Any:
|
||||
return v
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable required
|
||||
|
||||
|
||||
@app.get("/nullable-required")
|
||||
async def read_nullable_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Cookie(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Cookie(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableRequired(BaseModel):
|
||||
int_val: int | None
|
||||
str_val: str | None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-required")
|
||||
async def read_model_nullable_required(
|
||||
params: Annotated[ModelNullableRequired, Cookie()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_schema(path: str):
|
||||
assert app.openapi()["paths"][path]["get"]["parameters"] == snapshot(
|
||||
[
|
||||
{
|
||||
"required": True,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
},
|
||||
"name": "int_val",
|
||||
"in": "cookie",
|
||||
},
|
||||
{
|
||||
"required": True,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
},
|
||||
"name": "str_val",
|
||||
"in": "cookie",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["cookie", "int_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["cookie", "str_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"values",
|
||||
[
|
||||
{"int_val": "1", "str_val": "test"},
|
||||
{"int_val": "0", "str_val": ""},
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_value(path: str, values: dict[str, str]):
|
||||
client = TestClient(app)
|
||||
client.cookies.set("int_val", values["int_val"])
|
||||
client.cookies.set("str_val", values["str_val"])
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": int(values["int_val"]),
|
||||
"str_val": values["str_val"],
|
||||
"fields_set": IsOneOf(None, IsList("int_val", "str_val", check_order=False)),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with default=None
|
||||
|
||||
|
||||
@app.get("/nullable-non-required")
|
||||
async def read_nullable_non_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Cookie(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Cookie(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableNonRequired(BaseModel):
|
||||
int_val: int | None = None
|
||||
str_val: str | None = None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-non-required")
|
||||
async def read_model_nullable_non_required(
|
||||
params: Annotated[ModelNullableNonRequired, Cookie()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_schema(path: str):
|
||||
assert app.openapi()["paths"][path]["get"]["parameters"] == snapshot(
|
||||
[
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"name": "int_val",
|
||||
"in": "cookie",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"name": "str_val",
|
||||
"in": "cookie",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"int_val": None,
|
||||
"str_val": None,
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"values",
|
||||
[
|
||||
{"int_val": "1", "str_val": "test"},
|
||||
{"int_val": "0", "str_val": ""},
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_value(path: str, values: dict[str, str]):
|
||||
client = TestClient(app)
|
||||
client.cookies.set("int_val", values["int_val"])
|
||||
client.cookies.set("str_val", values["str_val"])
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": int(values["int_val"]),
|
||||
"str_val": values["str_val"],
|
||||
"fields_set": IsOneOf(None, IsList("int_val", "str_val", check_order=False)),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with not-None default
|
||||
|
||||
|
||||
@app.get("/nullable-with-non-null-default")
|
||||
async def read_nullable_with_non_null_default(
|
||||
*,
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Cookie(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = -1,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Cookie(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = "default",
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableWithNonNullDefault(BaseModel):
|
||||
int_val: int | None = -1
|
||||
str_val: str | None = "default"
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-with-non-null-default")
|
||||
async def read_model_nullable_with_non_null_default(
|
||||
params: Annotated[ModelNullableWithNonNullDefault, Cookie()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_schema(path: str):
|
||||
assert app.openapi()["paths"][path]["get"]["parameters"] == snapshot(
|
||||
[
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": -1,
|
||||
},
|
||||
"name": "int_val",
|
||||
"in": "cookie",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": "default",
|
||||
},
|
||||
"name": "str_val",
|
||||
"in": "cookie",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Missing parameters are pre-populated with default values before validation"
|
||||
)
|
||||
def test_nullable_with_non_null_default_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200 # pragma: no cover
|
||||
assert response.json() == { # pragma: no cover
|
||||
"int_val": -1,
|
||||
"str_val": "default",
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
# TODO: Remove 'no cover' when the issue is fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"values",
|
||||
[
|
||||
{"int_val": "1", "str_val": "test"},
|
||||
{"int_val": "0", "str_val": ""},
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_value(path: str, values: dict[str, str]):
|
||||
client = TestClient(app)
|
||||
client.cookies.set("int_val", values["int_val"])
|
||||
client.cookies.set("str_val", values["str_val"])
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": int(values["int_val"]),
|
||||
"str_val": values["str_val"],
|
||||
"fields_set": IsOneOf(None, IsList("int_val", "str_val", check_order=False)),
|
||||
}
|
||||
@@ -0,0 +1,525 @@
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from dirty_equals import IsOneOf
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import Is, snapshot
|
||||
from pydantic import BeforeValidator
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
|
||||
from .utils import get_body_model_name
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def convert(v: Any) -> Any:
|
||||
return v
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable required
|
||||
|
||||
|
||||
@app.post("/nullable-required-bytes")
|
||||
async def read_nullable_required_bytes(
|
||||
file: Annotated[
|
||||
bytes | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
files: Annotated[
|
||||
list[bytes] | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"file": len(file) if file is not None else None,
|
||||
"files": [len(f) for f in files] if files is not None else None,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/nullable-required-uploadfile")
|
||||
async def read_nullable_required_uploadfile(
|
||||
file: Annotated[
|
||||
UploadFile | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
files: Annotated[
|
||||
list[UploadFile] | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"file": file.size if file is not None else None,
|
||||
"files": [f.size for f in files] if files is not None else None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required-bytes",
|
||||
"/nullable-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_schema(path: str):
|
||||
openapi = app.openapi()
|
||||
body_model_name = get_body_model_name(openapi, path)
|
||||
|
||||
assert openapi["components"]["schemas"][body_model_name] == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"file": {
|
||||
"title": "File",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"contentMediaType": "application/octet-stream",
|
||||
},
|
||||
{"type": "null"},
|
||||
],
|
||||
},
|
||||
"files": {
|
||||
"title": "Files",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"contentMediaType": "application/octet-stream",
|
||||
},
|
||||
},
|
||||
{"type": "null"},
|
||||
],
|
||||
},
|
||||
},
|
||||
"required": ["file", "files"],
|
||||
"title": Is(body_model_name),
|
||||
"type": "object",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required-bytes",
|
||||
"/nullable-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["body", "file"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["body", "files"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required-bytes",
|
||||
"/nullable-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_empty_file(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
files=[("file", b""), ("files", b""), ("files", b"")],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
call_args = [call_args_item.args for call_args_item in mock_convert.call_args_list]
|
||||
file_call_arg_1 = call_args[0][0]
|
||||
files_call_arg_1 = call_args[1][0]
|
||||
|
||||
assert (
|
||||
(file_call_arg_1 == b"") # file as bytes
|
||||
or isinstance(file_call_arg_1, StarletteUploadFile) # file as UploadFile
|
||||
)
|
||||
assert (
|
||||
(files_call_arg_1 == [b"", b""]) # files as bytes
|
||||
or all( # files as UploadFile
|
||||
isinstance(f, StarletteUploadFile) for f in files_call_arg_1
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"file": 0,
|
||||
"files": [0, 0],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required-bytes",
|
||||
"/nullable-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_file(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
files=[
|
||||
("file", b"test 1"),
|
||||
("files", b"test 2"),
|
||||
("files", b"test 3"),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"file": 6, "files": [6, 6]}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with default=None
|
||||
|
||||
|
||||
@app.post("/nullable-non-required-bytes")
|
||||
async def read_nullable_non_required_bytes(
|
||||
file: Annotated[
|
||||
bytes | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
files: Annotated[
|
||||
list[bytes] | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
):
|
||||
return {
|
||||
"file": len(file) if file is not None else None,
|
||||
"files": [len(f) for f in files] if files is not None else None,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/nullable-non-required-uploadfile")
|
||||
async def read_nullable_non_required_uploadfile(
|
||||
file: Annotated[
|
||||
UploadFile | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
files: Annotated[
|
||||
list[UploadFile] | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
):
|
||||
return {
|
||||
"file": file.size if file is not None else None,
|
||||
"files": [f.size for f in files] if files is not None else None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required-bytes",
|
||||
"/nullable-non-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_schema(path: str):
|
||||
openapi = app.openapi()
|
||||
body_model_name = get_body_model_name(openapi, path)
|
||||
|
||||
assert openapi["components"]["schemas"][body_model_name] == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"file": {
|
||||
"title": "File",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"contentMediaType": "application/octet-stream",
|
||||
},
|
||||
{"type": "null"},
|
||||
],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"files": {
|
||||
"title": "Files",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"contentMediaType": "application/octet-stream",
|
||||
},
|
||||
},
|
||||
{"type": "null"},
|
||||
],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
},
|
||||
"title": Is(body_model_name),
|
||||
"type": "object",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required-bytes",
|
||||
"/nullable-non-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"file": None,
|
||||
"files": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required-bytes",
|
||||
"/nullable-non-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_empty_file(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
files=[("file", b""), ("files", b""), ("files", b"")],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
call_args = [call_args_item.args for call_args_item in mock_convert.call_args_list]
|
||||
file_call_arg_1 = call_args[0][0]
|
||||
files_call_arg_1 = call_args[1][0]
|
||||
|
||||
assert (
|
||||
(file_call_arg_1 == b"") # file as bytes
|
||||
or isinstance(file_call_arg_1, StarletteUploadFile) # file as UploadFile
|
||||
)
|
||||
assert (
|
||||
(files_call_arg_1 == [b"", b""]) # files as bytes
|
||||
or all( # files as UploadFile
|
||||
isinstance(f, StarletteUploadFile) for f in files_call_arg_1
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"file": 0, "files": [0, 0]}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required-bytes",
|
||||
"/nullable-non-required-uploadfile",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_file(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
files=[("file", b"test 1"), ("files", b"test 2"), ("files", b"test 3")],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"file": 6, "files": [6, 6]}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with not-None default
|
||||
|
||||
|
||||
@app.post("/nullable-with-non-null-default-bytes")
|
||||
async def read_nullable_with_non_null_default_bytes(
|
||||
*,
|
||||
file: Annotated[
|
||||
bytes | None,
|
||||
File(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = b"default",
|
||||
files: Annotated[
|
||||
list[bytes] | None,
|
||||
File(default_factory=lambda: [b"default"]),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"file": len(file) if file is not None else None,
|
||||
"files": [len(f) for f in files] if files is not None else None,
|
||||
}
|
||||
|
||||
|
||||
# Note: It seems to be not possible to create endpoint with UploadFile and non-None default
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default-bytes",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_schema(path: str):
|
||||
openapi = app.openapi()
|
||||
body_model_name = get_body_model_name(openapi, path)
|
||||
|
||||
assert openapi["components"]["schemas"][body_model_name] == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"file": {
|
||||
"title": "File",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string",
|
||||
"contentMediaType": "application/octet-stream",
|
||||
},
|
||||
{"type": "null"},
|
||||
],
|
||||
"default": "default", # <= Default value here looks strange to me
|
||||
},
|
||||
"files": {
|
||||
"title": "Files",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"contentMediaType": "application/octet-stream",
|
||||
},
|
||||
},
|
||||
{"type": "null"},
|
||||
],
|
||||
},
|
||||
},
|
||||
"title": Is(body_model_name),
|
||||
"type": "object",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-with-non-null-default-bytes",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="AttributeError: 'bytes' object has no attribute 'read'",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(path)
|
||||
|
||||
assert mock_convert.call_count == 0, ( # pragma: no cover
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200 # pragma: no cover
|
||||
assert response.json() == {"file": None, "files": None} # pragma: no cover
|
||||
# TODO: Remove 'no cover' when the issue is fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default-bytes",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_empty_file(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
files=[("file", b""), ("files", b""), ("files", b"")],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
call_args = [call_args_item.args for call_args_item in mock_convert.call_args_list]
|
||||
file_call_arg_1 = call_args[0][0]
|
||||
files_call_arg_1 = call_args[1][0]
|
||||
|
||||
assert (
|
||||
(file_call_arg_1 == b"") # file as bytes
|
||||
or isinstance(file_call_arg_1, StarletteUploadFile) # file as UploadFile
|
||||
)
|
||||
assert (
|
||||
(files_call_arg_1 == [b"", b""]) # files as bytes
|
||||
or all( # files as UploadFile
|
||||
isinstance(f, StarletteUploadFile) for f in files_call_arg_1
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"file": 0, "files": [0, 0]}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default-bytes",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_file(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
files=[("file", b"test 1"), ("files", b"test 2"), ("files", b"test 3")],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 2, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {"file": 6, "files": [6, 6]}
|
||||
@@ -0,0 +1,746 @@
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
import pytest
|
||||
from dirty_equals import IsList, IsOneOf, IsPartialDict
|
||||
from fastapi import FastAPI, Form
|
||||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import Is, snapshot
|
||||
from pydantic import BaseModel, BeforeValidator, field_validator
|
||||
|
||||
from .utils import get_body_model_name
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def convert(v: Any) -> Any:
|
||||
return v
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable required
|
||||
|
||||
|
||||
@app.post("/nullable-required")
|
||||
async def read_nullable_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableRequired(BaseModel):
|
||||
int_val: int | None
|
||||
str_val: str | None
|
||||
list_val: list[int] | None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.post("/model-nullable-required")
|
||||
async def read_model_nullable_required(
|
||||
params: Annotated[ModelNullableRequired, Form()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_schema(path: str):
|
||||
openapi = app.openapi()
|
||||
body_model_name = get_body_model_name(openapi, path)
|
||||
|
||||
assert openapi["components"]["schemas"][body_model_name] == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"int_val": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
},
|
||||
"str_val": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
},
|
||||
"list_val": {
|
||||
"title": "List Val",
|
||||
"anyOf": [
|
||||
{"type": "array", "items": {"type": "integer"}},
|
||||
{"type": "null"},
|
||||
],
|
||||
},
|
||||
},
|
||||
"required": ["int_val", "str_val", "list_val"],
|
||||
"title": Is(body_model_name),
|
||||
"type": "object",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["body", "int_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["body", "str_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["body", "list_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-required",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty str is replaced with None even for required parameters"
|
||||
),
|
||||
),
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_empty_str_to_str_val(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
data={
|
||||
"int_val": "0", # Empty string would cause validation error (see below)
|
||||
"str_val": "",
|
||||
"list_val": "0", # Empty string would cause validation error (see below)
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert mock_convert.call_args_list == [
|
||||
call("0"), # int_val
|
||||
call(""), # str_val
|
||||
call(["0"]), # list_val
|
||||
]
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 0,
|
||||
"str_val": "",
|
||||
"list_val": [0],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-required",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty str is replaced with None even for required parameters"
|
||||
),
|
||||
),
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_empty_str_to_int_val_and_list_val(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
data={
|
||||
"int_val": "",
|
||||
"str_val": "",
|
||||
"list_val": "",
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert mock_convert.call_args_list == [
|
||||
call(""), # int_val
|
||||
call(""), # str_val
|
||||
call([""]), # list_val
|
||||
]
|
||||
assert response.status_code == 422, response.text
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"input": "",
|
||||
"loc": ["body", "int_val"],
|
||||
"msg": "Input should be a valid integer, unable to parse string as an integer",
|
||||
"type": "int_parsing",
|
||||
},
|
||||
{
|
||||
"input": "",
|
||||
"loc": ["body", "list_val", 0],
|
||||
"msg": "Input should be a valid integer, unable to parse string as an integer",
|
||||
"type": "int_parsing",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_value(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path, data={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]}
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "test",
|
||||
"list_val": [1, 2],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with default=None
|
||||
|
||||
|
||||
@app.post("/nullable-non-required")
|
||||
async def read_nullable_non_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableNonRequired(BaseModel):
|
||||
int_val: int | None = None
|
||||
str_val: str | None = None
|
||||
list_val: list[int] | None = None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.post("/model-nullable-non-required")
|
||||
async def read_model_nullable_non_required(
|
||||
params: Annotated[ModelNullableNonRequired, Form()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_schema(path: str):
|
||||
openapi = app.openapi()
|
||||
body_model_name = get_body_model_name(openapi, path)
|
||||
|
||||
assert openapi["components"]["schemas"][body_model_name] == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"int_val": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"str_val": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"list_val": {
|
||||
"title": "List Val",
|
||||
"anyOf": [
|
||||
{"type": "array", "items": {"type": "integer"}},
|
||||
{"type": "null"},
|
||||
],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
},
|
||||
"title": Is(body_model_name),
|
||||
"type": "object",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"int_val": None,
|
||||
"str_val": None,
|
||||
"list_val": None,
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
pytest.param(
|
||||
"/model-nullable-non-required",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty strings are not replaced with None for parameters declared as model"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_empty_str_to_str_val_and_int_val(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
data={
|
||||
"int_val": "",
|
||||
"str_val": "",
|
||||
"list_val": "0", # Empty string would cause validation error (see below)
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 1, "Validator should be called for list_val only"
|
||||
assert mock_convert.call_args_list == [
|
||||
call(["0"]), # list_val
|
||||
]
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": None,
|
||||
"str_val": None,
|
||||
"list_val": [0],
|
||||
"fields_set": IsOneOf(None, IsList("list_val", check_order=False)),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
pytest.param(
|
||||
"/model-nullable-non-required",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty strings are not replaced with None for parameters declared as model"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_empty_str_to_all(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
data={
|
||||
"int_val": "",
|
||||
"str_val": "",
|
||||
"list_val": "",
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 1, "Validator should be called for list_val only"
|
||||
assert mock_convert.call_args_list == [
|
||||
call([""]), # list_val
|
||||
]
|
||||
assert response.status_code == 422, response.text
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"input": "",
|
||||
"loc": ["body", "list_val", 0],
|
||||
"msg": "Input should be a valid integer, unable to parse string as an integer",
|
||||
"type": "int_parsing",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_value(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path, data={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]}
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "test",
|
||||
"list_val": [1, 2],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with not-None default
|
||||
|
||||
|
||||
@app.post("/nullable-with-non-null-default")
|
||||
async def read_nullable_with_non_null_default(
|
||||
*,
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = -1,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Form(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = "default",
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Form(default_factory=lambda: [0]),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableWithNonNullDefault(BaseModel):
|
||||
int_val: int | None = -1
|
||||
str_val: str | None = "default"
|
||||
list_val: list[int] | None = [0]
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.post("/model-nullable-with-non-null-default")
|
||||
async def read_model_nullable_with_non_null_default(
|
||||
params: Annotated[ModelNullableWithNonNullDefault, Form()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_schema(path: str):
|
||||
openapi = app.openapi()
|
||||
body_model_name = get_body_model_name(openapi, path)
|
||||
body_model = openapi["components"]["schemas"][body_model_name]
|
||||
|
||||
assert body_model == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"int_val": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": -1,
|
||||
},
|
||||
"str_val": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": "default",
|
||||
},
|
||||
"list_val": IsPartialDict(
|
||||
{
|
||||
"title": "List Val",
|
||||
"anyOf": [
|
||||
{"type": "array", "items": {"type": "integer"}},
|
||||
{"type": "null"},
|
||||
],
|
||||
}
|
||||
),
|
||||
},
|
||||
"title": Is(body_model_name),
|
||||
"type": "object",
|
||||
}
|
||||
)
|
||||
|
||||
if path == "/model-nullable-with-non-null-default":
|
||||
# Check default value for list_val param for model-based parameters only.
|
||||
# default_factory is not reflected in OpenAPI schema
|
||||
assert body_model["properties"]["list_val"]["default"] == [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Missing parameters are pre-populated with default values before validation"
|
||||
)
|
||||
def test_nullable_with_non_null_default_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200 # pragma: no cover
|
||||
assert response.json() == { # pragma: no cover
|
||||
"int_val": -1,
|
||||
"str_val": "default",
|
||||
"list_val": [0],
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
# TODO: Remove 'no cover' when the issue is fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-with-non-null-default",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty strings are replaced with default values before validation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
"/model-nullable-with-non-null-default",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty strings are not replaced with None for parameters declared as model"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_empty_str_to_str_val_and_int_val(
|
||||
path: str,
|
||||
):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
data={
|
||||
"int_val": "",
|
||||
"str_val": "",
|
||||
"list_val": "0", # Empty string would cause validation error (see below)
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 1, "Validator should be called for list_val only"
|
||||
assert mock_convert.call_args_list == [ # pragma: no cover
|
||||
call(["0"]), # list_val
|
||||
]
|
||||
assert response.status_code == 200, response.text # pragma: no cover
|
||||
assert response.json() == { # pragma: no cover
|
||||
"int_val": -1,
|
||||
"str_val": "default",
|
||||
"list_val": [0],
|
||||
"fields_set": IsOneOf(None, IsList("list_val", check_order=False)),
|
||||
}
|
||||
# TODO: Remove 'no cover' when the issue is fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-with-non-null-default",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty strings are replaced with default values before validation"
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
"/model-nullable-with-non-null-default",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Empty strings are not replaced with None for parameters declared as model"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_empty_str_to_all(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path,
|
||||
data={
|
||||
"int_val": "",
|
||||
"str_val": "",
|
||||
"list_val": "",
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 1, "Validator should be called for list_val only"
|
||||
assert mock_convert.call_args_list == [ # pragma: no cover
|
||||
call([""]), # list_val
|
||||
]
|
||||
assert response.status_code == 422, response.text # pragma: no cover
|
||||
assert response.json() == snapshot( # pragma: no cover
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"input": "",
|
||||
"loc": ["body", "list_val", 0],
|
||||
"msg": "Input should be a valid integer, unable to parse string as an integer",
|
||||
"type": "int_parsing",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
# TODO: Remove 'no cover' when the issue is fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_value(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.post(
|
||||
path, data={"int_val": "1", "str_val": "test", "list_val": ["1", "2"]}
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "test",
|
||||
"list_val": [1, 2],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
@@ -0,0 +1,634 @@
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from dirty_equals import AnyThing, IsList, IsOneOf, IsPartialDict
|
||||
from fastapi import FastAPI, Header
|
||||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import snapshot
|
||||
from pydantic import BaseModel, BeforeValidator, field_validator
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def convert(v: Any) -> Any:
|
||||
return v
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable required
|
||||
|
||||
|
||||
@app.get("/nullable-required")
|
||||
async def read_nullable_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableRequired(BaseModel):
|
||||
int_val: int | None
|
||||
str_val: str | None
|
||||
list_val: list[int] | None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-required")
|
||||
async def read_model_nullable_required(
|
||||
params: Annotated[ModelNullableRequired, Header()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-required",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Title contains hyphens for single Header parameters"
|
||||
),
|
||||
),
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_schema(path: str):
|
||||
assert app.openapi()["paths"][path]["get"]["parameters"] == snapshot(
|
||||
[
|
||||
{
|
||||
"required": True,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
},
|
||||
"name": "int-val",
|
||||
"in": "header",
|
||||
},
|
||||
{
|
||||
"required": True,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
},
|
||||
"name": "str-val",
|
||||
"in": "header",
|
||||
},
|
||||
{
|
||||
"required": True,
|
||||
"schema": {
|
||||
"title": "List Val",
|
||||
"anyOf": [
|
||||
{"type": "array", "items": {"type": "integer"}},
|
||||
{"type": "null"},
|
||||
],
|
||||
},
|
||||
"name": "list-val",
|
||||
"in": "header",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
pytest.param(
|
||||
"/model-nullable-required",
|
||||
marks=pytest.mark.xfail(
|
||||
reason=(
|
||||
"For parameters declared as model, underscores are not replaced "
|
||||
"with hyphens in error loc"
|
||||
)
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_nullable_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["header", "int-val"],
|
||||
"msg": "Field required",
|
||||
"input": AnyThing(),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["header", "str-val"],
|
||||
"msg": "Field required",
|
||||
"input": AnyThing(),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["header", "list-val"],
|
||||
"msg": "Field required",
|
||||
"input": AnyThing(),
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_value(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(
|
||||
path,
|
||||
headers=[
|
||||
("int-val", "1"),
|
||||
("str-val", "test"),
|
||||
("list-val", "1"),
|
||||
("list-val", "2"),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "test",
|
||||
"list_val": [1, 2],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_empty_str_to_str_val(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(
|
||||
path,
|
||||
headers=[
|
||||
("int-val", "1"),
|
||||
("str-val", ""),
|
||||
("list-val", "1"),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "",
|
||||
"list_val": [1],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with default=None
|
||||
|
||||
|
||||
@app.get("/nullable-non-required")
|
||||
async def read_nullable_non_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableNonRequired(BaseModel):
|
||||
int_val: int | None = None
|
||||
str_val: str | None = None
|
||||
list_val: list[int] | None = None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-non-required")
|
||||
async def read_model_nullable_non_required(
|
||||
params: Annotated[ModelNullableNonRequired, Header()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-non-required",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Title contains hyphens for single Header parameters"
|
||||
),
|
||||
),
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_schema(path: str):
|
||||
assert app.openapi()["paths"][path]["get"]["parameters"] == snapshot(
|
||||
[
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"name": "int-val",
|
||||
"in": "header",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"name": "str-val",
|
||||
"in": "header",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "List Val",
|
||||
"anyOf": [
|
||||
{"type": "array", "items": {"type": "integer"}},
|
||||
{"type": "null"},
|
||||
],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"name": "list-val",
|
||||
"in": "header",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"int_val": None,
|
||||
"str_val": None,
|
||||
"list_val": None,
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_value(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(
|
||||
path,
|
||||
headers=[
|
||||
("int-val", "1"),
|
||||
("str-val", "test"),
|
||||
("list-val", "1"),
|
||||
("list-val", "2"),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "test",
|
||||
"list_val": [1, 2],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_empty_str_to_str_val(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(
|
||||
path,
|
||||
headers=[
|
||||
("int-val", "1"),
|
||||
("str-val", ""),
|
||||
("list-val", "1"),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "",
|
||||
"list_val": [1],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with not-None default
|
||||
|
||||
|
||||
@app.get("/nullable-with-non-null-default")
|
||||
async def read_nullable_with_non_null_default(
|
||||
*,
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = -1,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
Header(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = "default",
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Header(default_factory=lambda: [0]),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableWithNonNullDefault(BaseModel):
|
||||
int_val: int | None = -1
|
||||
str_val: str | None = "default"
|
||||
list_val: list[int] | None = [0]
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_fields(cls, v):
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-with-non-null-default")
|
||||
async def read_model_nullable_with_non_null_default(
|
||||
params: Annotated[ModelNullableWithNonNullDefault, Header()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
pytest.param(
|
||||
"/nullable-with-non-null-default",
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Title contains hyphens for single Header parameters"
|
||||
),
|
||||
),
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_schema(path: str):
|
||||
parameters = app.openapi()["paths"][path]["get"]["parameters"]
|
||||
assert parameters == snapshot(
|
||||
[
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": -1,
|
||||
},
|
||||
"name": "int-val",
|
||||
"in": "header",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": "default",
|
||||
},
|
||||
"name": "str-val",
|
||||
"in": "header",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": IsPartialDict(
|
||||
{
|
||||
"title": "List Val",
|
||||
"anyOf": [
|
||||
{"type": "array", "items": {"type": "integer"}},
|
||||
{"type": "null"},
|
||||
],
|
||||
}
|
||||
),
|
||||
"name": "list-val",
|
||||
"in": "header",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
if path == "/model-nullable-with-non-null-default":
|
||||
# Check default value for list_val param for model-based parameters only.
|
||||
# default_factory is not reflected in OpenAPI schema
|
||||
assert parameters[2]["schema"]["default"] == [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Missing parameters are pre-populated with default values before validation"
|
||||
)
|
||||
def test_nullable_with_non_null_default_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200 # pragma: no cover
|
||||
assert response.json() == { # pragma: no cover
|
||||
"int_val": -1,
|
||||
"str_val": "default",
|
||||
"list_val": [0],
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
# TODO: Remove 'no cover' when the issue is fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_value(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(
|
||||
path,
|
||||
headers=[
|
||||
("int-val", "1"),
|
||||
("str-val", "test"),
|
||||
("list-val", "1"),
|
||||
("list-val", "2"),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "test",
|
||||
"list_val": [1, 2],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_empty_str_to_str_val(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(
|
||||
path,
|
||||
headers=[
|
||||
("int-val", "1"),
|
||||
("str-val", ""),
|
||||
("list-val", "1"),
|
||||
],
|
||||
)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": 1,
|
||||
"str_val": "",
|
||||
"list_val": [1],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
# Not appllicable for Path parameters
|
||||
# Path parameters cannot have default values or be nullable
|
||||
@@ -0,0 +1,507 @@
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from dirty_equals import IsList, IsOneOf, IsPartialDict
|
||||
from fastapi import FastAPI, Query
|
||||
from fastapi.testclient import TestClient
|
||||
from inline_snapshot import snapshot
|
||||
from pydantic import BaseModel, BeforeValidator, field_validator
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def convert(v: Any) -> Any:
|
||||
return v
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable required
|
||||
|
||||
|
||||
@app.get("/nullable-required")
|
||||
async def read_nullable_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Query(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableRequired(BaseModel):
|
||||
int_val: int | None
|
||||
str_val: str | None
|
||||
list_val: list[int] | None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_all(cls, v: Any) -> Any:
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-required")
|
||||
async def read_model_nullable_required(
|
||||
params: Annotated[ModelNullableRequired, Query()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_schema(path: str):
|
||||
assert app.openapi()["paths"][path]["get"]["parameters"] == snapshot(
|
||||
[
|
||||
{
|
||||
"required": True,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
},
|
||||
"name": "int_val",
|
||||
"in": "query",
|
||||
},
|
||||
{
|
||||
"required": True,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
},
|
||||
"name": "str_val",
|
||||
"in": "query",
|
||||
},
|
||||
{
|
||||
"in": "query",
|
||||
"name": "list_val",
|
||||
"required": True,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{"items": {"type": "integer"}, "type": "array"},
|
||||
{"type": "null"},
|
||||
],
|
||||
"title": "List Val",
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
assert response.json() == snapshot(
|
||||
{
|
||||
"detail": [
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["query", "int_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["query", "str_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
{
|
||||
"type": "missing",
|
||||
"loc": ["query", "list_val"],
|
||||
"msg": "Field required",
|
||||
"input": IsOneOf(None, {}),
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-required",
|
||||
"/model-nullable-required",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"values",
|
||||
[
|
||||
{"int_val": "1", "str_val": "test", "list_val": ["1", "2"]},
|
||||
{"int_val": "0", "str_val": "", "list_val": ["0"]},
|
||||
],
|
||||
)
|
||||
def test_nullable_required_pass_value(path: str, values: dict[str, Any]):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path, params=values)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": int(values["int_val"]),
|
||||
"str_val": values["str_val"],
|
||||
"list_val": [int(v) for v in values["list_val"]],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with default=None
|
||||
|
||||
|
||||
@app.get("/nullable-non-required")
|
||||
async def read_nullable_non_required(
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Query(),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = None,
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableNonRequired(BaseModel):
|
||||
int_val: int | None = None
|
||||
str_val: str | None = None
|
||||
list_val: list[int] | None = None
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_all(cls, v: Any) -> Any:
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-non-required")
|
||||
async def read_model_nullable_non_required(
|
||||
params: Annotated[ModelNullableNonRequired, Query()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_schema(path: str):
|
||||
assert app.openapi()["paths"][path]["get"]["parameters"] == snapshot(
|
||||
[
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"name": "int_val",
|
||||
"in": "query",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
"name": "str_val",
|
||||
"in": "query",
|
||||
},
|
||||
{
|
||||
"in": "query",
|
||||
"name": "list_val",
|
||||
"required": False,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{"items": {"type": "integer"}, "type": "array"},
|
||||
{"type": "null"},
|
||||
],
|
||||
"title": "List Val",
|
||||
# "default": None, # `None` values are omitted in OpenAPI schema
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"int_val": None,
|
||||
"str_val": None,
|
||||
"list_val": None,
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-non-required",
|
||||
"/model-nullable-non-required",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"values",
|
||||
[
|
||||
{"int_val": "1", "str_val": "test", "list_val": ["1", "2"]},
|
||||
{"int_val": "0", "str_val": "", "list_val": ["0"]},
|
||||
],
|
||||
)
|
||||
def test_nullable_non_required_pass_value(path: str, values: dict[str, Any]):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path, params=values)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": int(values["int_val"]),
|
||||
"str_val": values["str_val"],
|
||||
"list_val": [int(v) for v in values["list_val"]],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# Nullable with not-None default
|
||||
|
||||
|
||||
@app.get("/nullable-with-non-null-default")
|
||||
async def read_nullable_with_non_null_default(
|
||||
*,
|
||||
int_val: Annotated[
|
||||
int | None,
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = -1,
|
||||
str_val: Annotated[
|
||||
str | None,
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
] = "default",
|
||||
list_val: Annotated[
|
||||
list[int] | None,
|
||||
Query(default_factory=lambda: [0]),
|
||||
BeforeValidator(lambda v: convert(v)),
|
||||
],
|
||||
):
|
||||
return {
|
||||
"int_val": int_val,
|
||||
"str_val": str_val,
|
||||
"list_val": list_val,
|
||||
"fields_set": None,
|
||||
}
|
||||
|
||||
|
||||
class ModelNullableWithNonNullDefault(BaseModel):
|
||||
int_val: int | None = -1
|
||||
str_val: str | None = "default"
|
||||
list_val: list[int] | None = [0]
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def convert_all(cls, v: Any) -> Any:
|
||||
return convert(v)
|
||||
|
||||
|
||||
@app.get("/model-nullable-with-non-null-default")
|
||||
async def read_model_nullable_with_non_null_default(
|
||||
params: Annotated[ModelNullableWithNonNullDefault, Query()],
|
||||
):
|
||||
return {
|
||||
"int_val": params.int_val,
|
||||
"str_val": params.str_val,
|
||||
"list_val": params.list_val,
|
||||
"fields_set": params.model_fields_set,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_schema(path: str):
|
||||
parameters = app.openapi()["paths"][path]["get"]["parameters"]
|
||||
assert parameters == snapshot(
|
||||
[
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Int Val",
|
||||
"anyOf": [{"type": "integer"}, {"type": "null"}],
|
||||
"default": -1,
|
||||
},
|
||||
"name": "int_val",
|
||||
"in": "query",
|
||||
},
|
||||
{
|
||||
"required": False,
|
||||
"schema": {
|
||||
"title": "Str Val",
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
"default": "default",
|
||||
},
|
||||
"name": "str_val",
|
||||
"in": "query",
|
||||
},
|
||||
{
|
||||
"in": "query",
|
||||
"name": "list_val",
|
||||
"required": False,
|
||||
"schema": IsPartialDict(
|
||||
{
|
||||
"anyOf": [
|
||||
{"items": {"type": "integer"}, "type": "array"},
|
||||
{"type": "null"},
|
||||
],
|
||||
"title": "List Val",
|
||||
}
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
if path == "/model-nullable-with-non-null-default":
|
||||
# Check default value for list_val param for model-based parameters only.
|
||||
# default_factory is not reflected in OpenAPI schema
|
||||
assert parameters[2]["schema"]["default"] == [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
@pytest.mark.xfail(
|
||||
reason="Missing parameters are pre-populated with default values before validation"
|
||||
)
|
||||
def test_nullable_with_non_null_default_missing(path: str):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path)
|
||||
|
||||
assert mock_convert.call_count == 0, (
|
||||
"Validator should not be called if the value is missing"
|
||||
)
|
||||
assert response.status_code == 200 # pragma: no cover
|
||||
assert response.json() == { # pragma: no cover
|
||||
"int_val": -1,
|
||||
"str_val": "default",
|
||||
"list_val": [0],
|
||||
"fields_set": IsOneOf(None, []),
|
||||
}
|
||||
# TODO: Remove 'no cover' when the issue is fixed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/nullable-with-non-null-default",
|
||||
"/model-nullable-with-non-null-default",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"values",
|
||||
[
|
||||
{"int_val": "1", "str_val": "test", "list_val": ["1", "2"]},
|
||||
{"int_val": "0", "str_val": "", "list_val": ["0"]},
|
||||
],
|
||||
)
|
||||
def test_nullable_with_non_null_default_pass_value(path: str, values: dict[str, Any]):
|
||||
client = TestClient(app)
|
||||
|
||||
with patch(f"{__name__}.convert", Mock(wraps=convert)) as mock_convert:
|
||||
response = client.get(path, params=values)
|
||||
|
||||
assert mock_convert.call_count == 3, "Validator should be called for each field"
|
||||
assert response.status_code == 200, response.text
|
||||
assert response.json() == {
|
||||
"int_val": int(values["int_val"]),
|
||||
"str_val": values["str_val"],
|
||||
"list_val": [int(v) for v in values["list_val"]],
|
||||
"fields_set": IsOneOf(
|
||||
None, IsList("int_val", "str_val", "list_val", check_order=False)
|
||||
),
|
||||
}
|
||||
Reference in New Issue
Block a user