Compare commits

..

18 Commits

Author SHA1 Message Date
Sebastián Ramírez
a6897963d5 🔖 Release version 0.61.0 2020-08-09 22:36:47 +02:00
Sebastián Ramírez
c14803e3fa 📝 Update release notes 2020-08-09 22:34:50 +02:00
Sebastián Ramírez
cdba8481c2 🔥 Remove old/unused parameter sqlalchemy_safe from jsonable_encoder (#1864) 2020-08-09 22:32:59 +02:00
Sebastián Ramírez
64ca596a11 📝 Update release notes 2020-08-09 22:19:46 +02:00
Sebastián Ramírez
e1758d107e ⬆ Require Pydantic > 1.0 (#1862)
* 🔥 Remove support for Pydantic < 1.0

* 🔥 Remove deprecated skip_defaults from jsonable_encoder and set default for exclude to None, as in Pydantic

* ♻️ Set default of response_model_exclude=None as in Pydantic

* ⬆️ Require Pydantic >=1.0.0 in requirements
2020-08-09 22:17:08 +02:00
Sebastián Ramírez
3390182fc9 📝 Update release notes 2020-08-09 17:04:35 +02:00
Sebastián Ramírez
912a43ee3d 📝 Add link to TestDriven.io course in docs (#1860) 2020-08-09 17:02:51 +02:00
Sebastián Ramírez
4020304945 📝 Update release notes 2020-08-09 16:43:07 +02:00
Elliana May
0a2fc78fea ✏️ Update docs to remove gender-specific references (#1824)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
2020-08-09 16:35:44 +02:00
Sebastián Ramírez
b91acfb457 📝 Update release notes 2020-08-09 15:58:58 +02:00
Nik
b9a0179a03 Add support for injecting HTTPConnection (#1827) 2020-08-09 15:56:41 +02:00
Rupsi Kaushik
5ed48ccdc8 Export WebSocketDisconnect and add example handling disconnections to docs (#1822)
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
2020-08-09 15:52:19 +02:00
Sebastián Ramírez
f32d67c8b9 📝 Update release notes 2020-08-09 13:12:09 +02:00
manlix
0752c7242d 🔊 Fix empty log message in docs example about raised exceptions (#1815) 2020-08-09 13:10:33 +02:00
Sebastián Ramírez
be855c696b 📝 Update release notes 2020-08-09 12:56:09 +02:00
Nima Mashhadi M. Reza
da9b5201c4 🔧 Add Flake8 linting (#1774)
Co-authored-by: nimashadix <nimashadix@pop-os.localdomain>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
2020-08-09 12:54:05 +02:00
Sebastián Ramírez
4daa6ef4e4 📝 Update release notes 2020-08-09 11:08:29 +02:00
Sebastián Ramírez
4ad7c55618 💚 Disable Gitter notification as it's currently broken (#1853)
...no idea why yet. 😔
2020-08-08 20:43:31 +02:00
45 changed files with 427 additions and 494 deletions

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B9
ignore = E203, E501, W503
exclude = __init__.py

View File

@@ -31,9 +31,9 @@ jobs:
env:
GITHUB_CONTEXT: ${{ toJson(github) }}
run: echo "$GITHUB_CONTEXT"
- name: Notify
env:
GITTER_TOKEN: ${{ secrets.GITTER_TOKEN }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
TAG: ${{ github.event.release.name }}
run: bash scripts/notify.sh
# - name: Notify
# env:
# GITTER_TOKEN: ${{ secrets.GITTER_TOKEN }}
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# TAG: ${{ github.event.release.name }}
# run: bash scripts/notify.sh

View File

@@ -16,3 +16,9 @@ In the next sections you will see other options, configurations, and additional
You could still use most of the features in **FastAPI** with the knowledge from the main [Tutorial - User Guide](../tutorial/){.internal-link target=_blank}.
And the next sections assume you already read it, and assume that you know those main ideas.
## TestDriven.io course
If you would like to take an advanced-beginner course to complement this section of the docs, you might want to check: <a href="https://testdriven.io/courses/tdd-fastapi/" class="external-link" target="_blank">Test-Driven Development with FastAPI and Docker</a> by **TestDriven.io**.
They are currently donating 10% of all profits to the development of **FastAPI**. 🎉 😄

View File

@@ -54,9 +54,9 @@ But by using the `secrets.compare_digest()` it will be secure against a type of
But what's a "timing attack"?
Let's imagine an attacker is trying to guess the username and password.
Let's imagine some attackers are trying to guess the username and password.
And that attacker sends a request with a username `johndoe` and a password `love123`.
And they send a request with a username `johndoe` and a password `love123`.
Then the Python code in your application would be equivalent to something like:
@@ -67,7 +67,7 @@ if "johndoe" == "stanleyjobson" and "love123" == "swordfish":
But right at the moment Python compares the first `j` in `johndoe` to the first `s` in `stanleyjobson`, it will return `False`, because it already knows that those two strings are not the same, thinking that "there's no need to waste more computation comparing the rest of the letters". And your application will say "incorrect user or password".
But then the attacker tries with username `stanleyjobsox` and password `love123`.
But then the attackers try with username `stanleyjobsox` and password `love123`.
And your application code does something like:
@@ -78,17 +78,17 @@ if "stanleyjobsox" == "stanleyjobson" and "love123" == "swordfish":
Python will have to compare the whole `stanleyjobso` in both `stanleyjobsox` and `stanleyjobson` before realizing that both strings are not the same. So it will take some extra microseconds to reply back "incorrect user or password".
#### The time to answer helps the attacker
#### The time to answer helps the attackers
At that point, by noticing that the server took some microseconds longer to send the "incorrect user or password" response, the attacker will know that she/he got _something_ right, some of the initial letters were right.
At that point, by noticing that the server took some microseconds longer to send the "incorrect user or password" response, the attackers will know that they got _something_ right, some of the initial letters were right.
And then she/he can try again knowing that it's probably something more similar to `stanleyjobsox` than to `johndoe`.
And then they can try again knowing that it's probably something more similar to `stanleyjobsox` than to `johndoe`.
#### A "professional" attack
Of course, the attacker would not try all this by hand, she/he would write a program to do it, possibly with thousands or millions of tests per second. And would get just one extra correct letter at a time.
Of course, the attackers would not try all this by hand, they would write a program to do it, possibly with thousands or millions of tests per second. And would get just one extra correct letter at a time.
But doing that, in some minutes or hours the attacker would have guessed the correct username and password, with the "help" of our application, just using the time taken to answer.
But doing that, in some minutes or hours the attackers would have guessed the correct username and password, with the "help" of our application, just using the time taken to answer.
#### Fix it with `secrets.compare_digest()`

View File

@@ -137,6 +137,33 @@ With that you can connect the WebSocket and then send and receive messages:
<img src="/img/tutorial/websockets/image05.png">
## Handling disconnections and multiple clients
When a WebSocket connection is closed, the `await websocket.receive_text()` will raise a `WebSocketDisconnect` exception, which you can then catch and handle like in this example.
```Python hl_lines="81-83"
{!../../../docs_src/websockets/tutorial003.py!}
```
To try it out:
* Open the app with several browser tabs.
* Write messages from them.
* Then close one of the tabs.
That will raise the `WebSocketDisconnect` exception, and all the other clients will receive a message like:
```
Client #1596980209979 left the chat
```
!!! tip
The app above is a minimal and simple example to demonstrate how to handle and broadcast messages to several WebSocket connections.
But have in mind that, as everything is handled in memory, in a single list, it will only work while the process is running, and will only work with a single process.
If you need something easy to integrate with FastAPI but that is more robust, supported by Redis, PostgreSQL or others, check <a href="https://github.com/encode/broadcaster" class="external-link" target="_blank">encode/broadcaster</a>.
## More info
To learn more about the options, check Starlette's documentation for:

View File

@@ -2,6 +2,39 @@
## Latest changes
## 0.61.0
### Features
* Add support for injecting `HTTPConnection` (as `Request` and `WebSocket`). Useful for sharing app state in dependencies. PR [#1827](https://github.com/tiangolo/fastapi/pull/1827) by [@nsidnev](https://github.com/nsidnev).
* Export `WebSocketDisconnect` and add example handling WebSocket disconnections to docs. PR [#1822](https://github.com/tiangolo/fastapi/pull/1822) by [@rkbeatss](https://github.com/rkbeatss).
### Breaking Changes
* Require Pydantic > `1.0.0`.
* Remove support for deprecated Pydantic `0.32.2`. This improves maintainability and allows new features.
* In `FastAPI` and `APIRouter`:
* Remove *path operation decorators* related/deprecated parameter `response_model_skip_defaults` (use `response_model_exclude_unset` instead).
* Change *path operation decorators* parameter default for `response_model_exclude` from `set()` to `None` (as is in Pydantic).
* In `encoders.jsonable_encoder`:
* Remove deprecated `skip_defaults`, use instead `exclude_unset`.
* Set default of `exclude` from `set()` to `None` (as is in Pydantic).
* PR [#1862](https://github.com/tiangolo/fastapi/pull/1862).
* In `encoders.jsonable_encoder` remove parameter `sqlalchemy_safe`.
* It was an early hack to allow returning SQLAlchemy models, but it was never documented, and the recommended way is using Pydantic's `orm_mode` as described in the tutorial: [SQL (Relational) Databases](https://fastapi.tiangolo.com/tutorial/sql-databases/).
* PR [#1864](https://github.com/tiangolo/fastapi/pull/1864).
### Docs
* Add link to the course by TestDriven.io: [Test-Driven Development with FastAPI and Docker](https://testdriven.io/courses/tdd-fastapi/). PR [#1860](https://github.com/tiangolo/fastapi/pull/1860).
* Fix empty log message in docs example about handling errors. PR [#1815](https://github.com/tiangolo/fastapi/pull/1815) by [@manlix](https://github.com/manlix).
* Reword text to reduce ambiguity while not being gender-specific. PR [#1824](https://github.com/tiangolo/fastapi/pull/1824) by [@Mause](https://github.com/Mause).
### Internal
* Add Flake8 linting. Original PR [#1774](https://github.com/tiangolo/fastapi/pull/1774) by [@MashhadiNima](https://github.com/MashhadiNima).
* Disable Gitter bot, as it's currently broken, and Gitter's response doesn't show the problem. PR [#1853](https://github.com/tiangolo/fastapi/pull/1853).
## 0.60.2
* Fix typo in docs for query parameters. PR [#1832](https://github.com/tiangolo/fastapi/pull/1832) by [@ycd](https://github.com/ycd).

View File

@@ -2,7 +2,7 @@
You can define background tasks to be run *after* returning a response.
This is useful for operations that need to happen after a request, but that the client doesn't really have to be waiting for the operation to complete before receiving his response.
This is useful for operations that need to happen after a request, but that the client doesn't really have to be waiting for the operation to complete before receiving the response.
This includes, for example:

View File

@@ -47,7 +47,7 @@ In this example, when the client request an item by an ID that doesn't exist, ra
### The resulting response
If the client requests `http://example.com/items/foo` (an `item_id` `"foo"`), he will receive an HTTP status code of 200, and a JSON response of:
If the client requests `http://example.com/items/foo` (an `item_id` `"foo"`), that client will receive an HTTP status code of 200, and a JSON response of:
```JSON
{
@@ -55,7 +55,7 @@ If the client requests `http://example.com/items/foo` (an `item_id` `"foo"`), he
}
```
But if the client requests `http://example.com/items/bar` (a non-existent `item_id` `"bar"`), he will receive an HTTP status code of 404 (the "not found" error), and a JSON response of:
But if the client requests `http://example.com/items/bar` (a non-existent `item_id` `"bar"`), that client will receive an HTTP status code of 404 (the "not found" error), and a JSON response of:
```JSON
{

View File

@@ -85,7 +85,7 @@ But in this case, the same **FastAPI** application will handle the API and the a
So, let's review it from that simplified point of view:
* The user types his `username` and `password` in the frontend, and hits `Enter`.
* The user types the `username` and `password` in the frontend, and hits `Enter`.
* The frontend (running in the user's browser) sends that `username` and `password` to a specific URL in our API (declared with `tokenUrl="token"`).
* The API checks that `username` and `password`, and responds with a "token" (we haven't implemented any of this yet).
* A "token" is just a string with some content that we can use later to verify this user.

View File

@@ -20,7 +20,7 @@ It is not encrypted, so, anyone could recover the information from the contents.
But it's signed. So, when you receive a token that you emitted, you can verify that you actually emitted it.
That way, you can create a token with an expiration of, let's say, 1 week. And then when the user comes back the next day with the token, you know she/he is still logged in to your system.
That way, you can create a token with an expiration of, let's say, 1 week. And then when the user comes back the next day with the token, you know that user is still logged in to your system.
After a week, the token will be expired and the user will not be authorized and will have to sign in again to get a new token. And if the user (or a third party) tried to modify the token to change the expiration, you would be able to discover it, because the signatures would not match.

View File

@@ -11,7 +11,7 @@ app = FastAPI()
@app.exception_handler(StarletteHTTPException)
async def custom_http_exception_handler(request, exc):
print(f"OMG! An HTTP error!: {exc}")
print(f"OMG! An HTTP error!: {repr(exc)}")
return await http_exception_handler(request, exc)

View File

@@ -0,0 +1,83 @@
from typing import List
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
app = FastAPI()
html = """
<!DOCTYPE html>
<html>
<head>
<title>Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<h2>Your ID: <span id="ws-id"></span></h2>
<form action="" onsubmit="sendMessage(event)">
<input type="text" id="messageText" autocomplete="off"/>
<button>Send</button>
</form>
<ul id='messages'>
</ul>
<script>
var client_id = Date.now()
document.querySelector("#ws-id").textContent = client_id;
var ws = new WebSocket(`ws://localhost:8000/ws/${client_id}`);
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
var content = document.createTextNode(event.data)
message.appendChild(content)
messages.appendChild(message)
};
function sendMessage(event) {
var input = document.getElementById("messageText")
ws.send(input.value)
input.value = ''
event.preventDefault()
}
</script>
</body>
</html>
"""
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@app.get("/")
async def get():
return HTMLResponse(html)
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
await manager.send_personal_message(f"You wrote: {data}", websocket)
await manager.broadcast(f"Client #{client_id} says: {data}")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f"Client #{client_id} left the chat")

View File

@@ -1,6 +1,6 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
__version__ = "0.60.2"
__version__ = "0.61.0"
from starlette import status
@@ -22,4 +22,4 @@ from .param_functions import (
from .requests import Request
from .responses import Response
from .routing import APIRouter
from .websockets import WebSocket
from .websockets import WebSocket, WebSocketDisconnect

View File

@@ -16,7 +16,6 @@ from fastapi.openapi.docs import (
)
from fastapi.openapi.utils import get_openapi
from fastapi.params import Depends
from fastapi.utils import warning_response_model_skip_defaults_deprecated
from starlette.applications import Starlette
from starlette.datastructures import State
from starlette.exceptions import HTTPException
@@ -198,9 +197,8 @@ class FastAPI(Starlette):
methods: Optional[List[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -208,8 +206,6 @@ class FastAPI(Starlette):
response_class: Optional[Type[Response]] = None,
name: Optional[str] = None,
) -> None:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
self.router.add_api_route(
path,
endpoint=endpoint,
@@ -227,9 +223,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -253,9 +247,8 @@ class FastAPI(Starlette):
methods: Optional[List[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -263,9 +256,6 @@ class FastAPI(Starlette):
response_class: Optional[Type[Response]] = None,
name: Optional[str] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
def decorator(func: Callable) -> Callable:
self.router.add_api_route(
path,
@@ -284,9 +274,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -344,9 +332,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -355,8 +342,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.get(
path,
response_model=response_model,
@@ -372,9 +357,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -398,9 +381,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -409,8 +391,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.put(
path,
response_model=response_model,
@@ -426,9 +406,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -452,9 +430,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -463,8 +440,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.post(
path,
response_model=response_model,
@@ -480,9 +455,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -506,9 +479,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -517,8 +489,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.delete(
path,
response_model=response_model,
@@ -534,9 +504,7 @@ class FastAPI(Starlette):
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
operation_id=operation_id,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -560,9 +528,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -571,8 +538,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.options(
path,
response_model=response_model,
@@ -588,9 +553,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -614,9 +577,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -625,8 +587,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.head(
path,
response_model=response_model,
@@ -642,9 +602,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -668,9 +626,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -679,8 +636,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.patch(
path,
response_model=response_model,
@@ -696,9 +651,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -722,9 +675,8 @@ class FastAPI(Starlette):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -733,8 +685,6 @@ class FastAPI(Starlette):
name: Optional[str] = None,
callbacks: Optional[List[routing.APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.router.trace(
path,
response_model=response_model,
@@ -750,9 +700,7 @@ class FastAPI(Starlette):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,

View File

@@ -1,12 +1,7 @@
from typing import Callable, List, Optional, Sequence
from fastapi.security.base import SecurityBase
try:
from pydantic.fields import ModelField
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic.fields import Field as ModelField # type: ignore
from pydantic.fields import ModelField
param_supported_types = (str, int, float, bool)
@@ -34,6 +29,7 @@ class Dependant:
call: Optional[Callable] = None,
request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None,
http_connection_param_name: Optional[str] = None,
response_param_name: Optional[str] = None,
background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None,
@@ -50,6 +46,7 @@ class Dependant:
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
self.http_connection_param_name = http_connection_param_name
self.response_param_name = response_param_name
self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes

View File

@@ -28,58 +28,31 @@ from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import (
PYDANTIC_1,
create_response_field,
get_field_info,
get_path_param_names,
)
from pydantic import BaseConfig, BaseModel, create_model
from fastapi.utils import create_response_field, get_path_param_names
from pydantic import BaseModel, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import (
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
FieldInfo,
ModelField,
Required,
)
from pydantic.schema import get_annotation_from_field_info
from pydantic.typing import ForwardRef, evaluate_forwardref
from pydantic.utils import lenient_issubclass
from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import Request
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
try:
from pydantic.fields import (
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
FieldInfo,
ModelField,
Required,
)
from pydantic.schema import get_annotation_from_field_info
from pydantic.typing import ForwardRef, evaluate_forwardref
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic import Schema as FieldInfo # type: ignore
from pydantic.fields import Field as ModelField # type: ignore
from pydantic.fields import Required, Shape # type: ignore
from pydantic.schema import get_annotation_from_schema # type: ignore
from pydantic.utils import ForwardRef, evaluate_forwardref # type: ignore
SHAPE_LIST = Shape.LIST
SHAPE_SEQUENCE = Shape.SEQUENCE
SHAPE_SET = Shape.SET
SHAPE_SINGLETON = Shape.SINGLETON
SHAPE_TUPLE = Shape.TUPLE
SHAPE_TUPLE_ELLIPSIS = Shape.TUPLE_ELLIPS
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Type[Any]:
return get_annotation_from_schema(annotation, field_info)
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
@@ -113,7 +86,7 @@ multipart_incorrect_install_error = (
def check_file_field(field: ModelField) -> None:
field_info = get_field_info(field)
field_info = field.field_info
if isinstance(field_info, params.Form):
try:
# __version__ is available in both multiparts, and can be mocked
@@ -239,7 +212,7 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
def is_scalar_field(field: ModelField) -> bool:
field_info = get_field_info(field)
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
@@ -354,7 +327,7 @@ def get_dependant(
) and is_scalar_sequence_field(param_field):
add_param_to_fields(field=param_field, dependant=dependant)
else:
field_info = get_field_info(param_field)
field_info = param_field.field_info
assert isinstance(
field_info, params.Body
), f"Param: {param_field.name} can only be a request body, using Body(...)"
@@ -371,6 +344,9 @@ def add_non_field_param_to_dependency(
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param.name
return True
elif lenient_issubclass(param.annotation, HTTPConnection):
dependant.http_connection_param_name = param.name
return True
elif lenient_issubclass(param.annotation, Response):
dependant.response_param_name = param.name
return True
@@ -427,16 +403,13 @@ def get_param_field(
)
field.required = required
if not had_schema and not is_scalar_field(field=field):
if PYDANTIC_1:
field.field_info = params.Body(field_info.default)
else:
field.schema = params.Body(field_info.default) # type: ignore # pragma: nocover
field.field_info = params.Body(field_info.default)
return field
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
field_info = cast(params.Param, get_field_info(field))
field_info = cast(params.Param, field.field_info)
if field_info.in_ == params.ParamTypes.path:
dependant.path_params.append(field)
elif field_info.in_ == params.ParamTypes.query:
@@ -607,6 +580,8 @@ async def solve_dependencies(
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
@@ -637,26 +612,17 @@ def request_params_to_args(
value = received_params.getlist(field.alias) or field.default
else:
value = received_params.get(field.alias)
field_info = get_field_info(field)
field_info = field.field_info
assert isinstance(
field_info, params.Param
), "Params must be subclasses of Param"
if value is None:
if field.required:
if PYDANTIC_1:
errors.append(
ErrorWrapper(
MissingError(), loc=(field_info.in_.value, field.alias)
)
)
else: # pragma: nocover
errors.append(
ErrorWrapper( # type: ignore
MissingError(),
loc=(field_info.in_.value, field.alias),
config=BaseConfig,
)
errors.append(
ErrorWrapper(
MissingError(), loc=(field_info.in_.value, field.alias)
)
)
else:
values[field.name] = deepcopy(field.default)
continue
@@ -680,7 +646,7 @@ async def request_body_to_args(
errors = []
if required_params:
field = required_params[0]
field_info = get_field_info(field)
field_info = field.field_info
embed = getattr(field_info, "embed", None)
field_alias_omitted = len(required_params) == 1 and not embed
if field_alias_omitted:
@@ -747,12 +713,7 @@ async def request_body_to_args(
def get_missing_field_error(loc: Tuple[str, ...]) -> ErrorWrapper:
if PYDANTIC_1:
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
else: # pragma: no cover
missing_field_error = ErrorWrapper( # type: ignore
MissingError(), loc=loc, config=BaseConfig,
)
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
return missing_field_error
@@ -770,7 +731,7 @@ def get_schema_compatible_field(*, field: ModelField) -> ModelField:
default=field.default,
required=field.required,
alias=field.alias,
field_info=get_field_info(field),
field_info=field.field_info,
)
return out_field
@@ -780,7 +741,7 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
if not flat_dependant.body_params:
return None
first_param = flat_dependant.body_params[0]
field_info = get_field_info(first_param)
field_info = first_param.field_info
embed = getattr(field_info, "embed", None)
body_param_names_set = {param.name for param in flat_dependant.body_params}
if len(body_param_names_set) == 1 and not embed:
@@ -791,7 +752,7 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
# in case a sub-dependency is evaluated with a single unique body field
# That is combined (embedded) with other body fields
for param in flat_dependant.body_params:
setattr(get_field_info(param), "embed", True)
setattr(param.field_info, "embed", True)
model_name = "Body_" + name
BodyModel = create_model(model_name)
for f in flat_dependant.body_params:
@@ -799,21 +760,17 @@ def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
required = any(True for f in flat_dependant.body_params if f.required)
BodyFieldInfo_kwargs: Dict[str, Any] = dict(default=None)
if any(
isinstance(get_field_info(f), params.File) for f in flat_dependant.body_params
):
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
BodyFieldInfo: Type[params.Body] = params.File
elif any(
isinstance(get_field_info(f), params.Form) for f in flat_dependant.body_params
):
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
BodyFieldInfo = params.Form
else:
BodyFieldInfo = params.Body
body_param_media_types = [
getattr(get_field_info(f), "media_type")
getattr(f.field_info, "media_type")
for f in flat_dependant.body_params
if isinstance(get_field_info(f), params.Body)
if isinstance(f.field_info, params.Body)
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]

View File

@@ -4,8 +4,6 @@ from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from fastapi.logger import logger
from fastapi.utils import PYDANTIC_1
from pydantic import BaseModel
from pydantic.json import ENCODERS_BY_TYPE
@@ -28,21 +26,13 @@ encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
def jsonable_encoder(
obj: Any,
include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
exclude: Union[SetIntStr, DictIntStrAny] = set(),
exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
by_alias: bool = True,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
custom_encoder: dict = {},
sqlalchemy_safe: bool = True,
) -> Any:
if skip_defaults is not None:
logger.warning( # pragma: nocover
"skip_defaults in jsonable_encoder has been deprecated in favor of "
"exclude_unset to keep in line with Pydantic v1, support for it will be "
"removed soon."
)
if include is not None and not isinstance(include, set):
include = set(include)
if exclude is not None and not isinstance(exclude, set):
@@ -51,24 +41,14 @@ def jsonable_encoder(
encoder = getattr(obj.__config__, "json_encoders", {})
if custom_encoder:
encoder.update(custom_encoder)
if PYDANTIC_1:
obj_dict = obj.dict(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=bool(exclude_unset or skip_defaults),
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
else: # pragma: nocover
if exclude_defaults:
raise ValueError("Cannot use exclude_defaults")
obj_dict = obj.dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=bool(exclude_unset or skip_defaults),
)
obj_dict = obj.dict(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
@@ -76,7 +56,6 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
custom_encoder=encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum):
return obj.value
@@ -87,14 +66,8 @@ def jsonable_encoder(
if isinstance(obj, dict):
encoded_dict = {}
for key, value in obj.items():
if (
(
not sqlalchemy_safe
or (not isinstance(key, str))
or (not key.startswith("_sa"))
)
and (value is not None or not exclude_none)
and ((include and key in include) or key not in exclude)
if (value is not None or not exclude_none) and (
(include and key in include) or not exclude or key not in exclude
):
encoded_key = jsonable_encoder(
key,
@@ -102,7 +75,6 @@ def jsonable_encoder(
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_value = jsonable_encoder(
value,
@@ -110,7 +82,6 @@ def jsonable_encoder(
exclude_unset=exclude_unset,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
@@ -127,7 +98,6 @@ def jsonable_encoder(
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
)
return encoded_list
@@ -163,5 +133,4 @@ def jsonable_encoder(
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)

View File

@@ -1,11 +1,8 @@
from typing import Any, Dict, Optional, Sequence
from fastapi.utils import PYDANTIC_1
from pydantic import ValidationError, create_model
from pydantic.error_wrappers import ErrorList
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.requests import Request
from starlette.websockets import WebSocket
class HTTPException(StarletteHTTPException):
@@ -32,15 +29,9 @@ class FastAPIError(RuntimeError):
class RequestValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList], *, body: Any = None) -> None:
self.body = body
if PYDANTIC_1:
super().__init__(errors, RequestErrorModel)
else:
super().__init__(errors, Request) # type: ignore # pragma: nocover
super().__init__(errors, RequestErrorModel)
class WebSocketRequestValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList]) -> None:
if PYDANTIC_1:
super().__init__(errors, WebSocketErrorModel)
else:
super().__init__(errors, WebSocket) # type: ignore # pragma: nocover
super().__init__(errors, WebSocketErrorModel)

View File

@@ -2,24 +2,13 @@ from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from fastapi.logger import logger
from pydantic import BaseModel
try:
from pydantic import AnyUrl, Field
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic import Schema as Field # type: ignore
from pydantic import UrlStr as AnyUrl # type: ignore
from pydantic import AnyUrl, BaseModel, Field
try:
import email_validator
assert email_validator # make autoflake ignore the unused import
try:
from pydantic import EmailStr
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic.types import EmailStr # type: ignore
from pydantic import EmailStr
except ImportError: # pragma: no cover
class EmailStr(str): # type: ignore

View File

@@ -16,10 +16,10 @@ from fastapi.params import Body, Param
from fastapi.utils import (
deep_dict_update,
generate_operation_id_for_path,
get_field_info,
get_model_definitions,
)
from pydantic import BaseModel
from pydantic.fields import ModelField
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
@@ -30,12 +30,6 @@ from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
try:
from pydantic.fields import ModelField
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic.fields import Field as ModelField # type: ignore
validation_error_definition = {
"title": "ValidationError",
"type": "object",
@@ -91,7 +85,7 @@ def get_openapi_operation_parameters(
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = get_field_info(param)
field_info = param.field_info
field_info = cast(Param, field_info)
# ignore mypy error until enum schemas are released
parameter = {
@@ -122,7 +116,7 @@ def get_openapi_operation_request_body(
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX # type: ignore
)
field_info = cast(Body, get_field_info(body_field))
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}

View File

@@ -1,11 +1,7 @@
from enum import Enum
from typing import Any, Callable, Optional, Sequence
try:
from pydantic.fields import FieldInfo
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic import Schema as FieldInfo # type: ignore
from pydantic.fields import FieldInfo
class ParamTypes(Enum):

View File

@@ -1 +1,2 @@
from starlette.requests import Request # noqa
from starlette.requests import HTTPConnection as HTTPConnection # noqa: F401
from starlette.requests import Request as Request # noqa: F401

View File

@@ -16,15 +16,13 @@ from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.openapi.constants import STATUS_CODES_WITH_NO_BODY
from fastapi.utils import (
PYDANTIC_1,
create_cloned_field,
create_response_field,
generate_operation_id_for_path,
get_field_info,
warning_response_model_skip_defaults_deprecated,
)
from pydantic import BaseModel
from pydantic.error_wrappers import ErrorWrapper, ValidationError
from pydantic.fields import ModelField
from starlette import routing
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
@@ -41,13 +39,6 @@ from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp
from starlette.websockets import WebSocket
try:
from pydantic.fields import FieldInfo, ModelField
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic import Schema as FieldInfo # type: ignore
from pydantic.fields import Field as ModelField # type: ignore
def _prepare_response_content(
res: Any,
@@ -57,17 +48,12 @@ def _prepare_response_content(
exclude_none: bool = False,
) -> Any:
if isinstance(res, BaseModel):
if PYDANTIC_1:
return res.dict(
by_alias=True,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
else:
return res.dict(
by_alias=True, skip_defaults=exclude_unset,
) # pragma: nocover
return res.dict(
by_alias=True,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
elif isinstance(res, list):
return [
_prepare_response_content(
@@ -96,7 +82,7 @@ async def serialize_response(
field: Optional[ModelField] = None,
response_content: Any,
include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
exclude: Union[SetIntStr, DictIntStrAny] = set(),
exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
@@ -156,7 +142,7 @@ def get_request_handler(
response_class: Type[Response] = JSONResponse,
response_field: Optional[ModelField] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
@@ -165,7 +151,7 @@ def get_request_handler(
) -> Callable:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(get_field_info(body_field), params.Form)
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
async def app(request: Request) -> Response:
try:
@@ -285,7 +271,7 @@ class APIRoute(routing.Route):
methods: Optional[Union[Set[str], List[str]]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
@@ -438,9 +424,8 @@ class APIRouter(routing.Router):
methods: Optional[Union[Set[str], List[str]]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -450,8 +435,6 @@ class APIRouter(routing.Router):
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> None:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
route_class = route_class_override or self.route_class
route = route_class(
path,
@@ -470,9 +453,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -499,9 +480,8 @@ class APIRouter(routing.Router):
methods: Optional[List[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -510,9 +490,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
def decorator(func: Callable) -> Callable:
self.add_api_route(
path,
@@ -531,9 +508,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -654,9 +629,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -665,8 +639,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -683,9 +655,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -709,9 +679,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -720,8 +689,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -738,9 +705,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -764,9 +729,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -775,8 +739,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -793,9 +755,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -819,9 +779,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -830,8 +789,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -848,9 +805,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -874,9 +829,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -885,8 +839,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -903,9 +855,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -929,9 +879,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -940,8 +889,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -958,9 +905,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -984,9 +929,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -995,8 +939,6 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -1013,9 +955,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
@@ -1039,9 +979,8 @@ class APIRouter(routing.Router):
deprecated: Optional[bool] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Union[SetIntStr, DictIntStrAny] = set(),
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_skip_defaults: Optional[bool] = None,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
@@ -1050,8 +989,7 @@ class APIRouter(routing.Router):
name: Optional[str] = None,
callbacks: Optional[List[APIRoute]] = None,
) -> Callable:
if response_model_skip_defaults is not None:
warning_response_model_skip_defaults_deprecated() # pragma: nocover
return self.api_route(
path=path,
response_model=response_model,
@@ -1068,9 +1006,7 @@ class APIRouter(routing.Router):
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=bool(
response_model_exclude_unset or response_model_skip_defaults
),
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,

View File

@@ -27,7 +27,7 @@ class OAuth2PasswordRequestForm:
print(data.client_secret)
return data
It creates the following Form request parameters in your endpoint:
grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
@@ -77,7 +77,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
print(data.client_secret)
return data
It creates the following Form request parameters in your endpoint:
grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".

View File

@@ -5,49 +5,13 @@ from enum import Enum
from typing import Any, Dict, Optional, Set, Type, Union, cast
import fastapi
from fastapi.logger import logger
from fastapi.openapi.constants import REF_PREFIX
from pydantic import BaseConfig, BaseModel, create_model
from pydantic.class_validators import Validator
from pydantic.fields import FieldInfo, ModelField, UndefinedType
from pydantic.schema import model_process_schema
from pydantic.utils import lenient_issubclass
try:
from pydantic.fields import FieldInfo, ModelField, UndefinedType
PYDANTIC_1 = True
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic import Schema as FieldInfo # type: ignore
from pydantic.fields import Field as ModelField # type: ignore
class UndefinedType: # type: ignore
def __repr__(self) -> str:
return "PydanticUndefined"
logger.warning(
"Pydantic versions < 1.0.0 are deprecated in FastAPI and support will be "
"removed soon."
)
PYDANTIC_1 = False
# TODO: remove when removing support for Pydantic < 1.0.0
def get_field_info(field: ModelField) -> FieldInfo:
if PYDANTIC_1:
return field.field_info # type: ignore
else:
return field.schema # type: ignore # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
def warning_response_model_skip_defaults_deprecated() -> None:
logger.warning( # pragma: nocover
"response_model_skip_defaults has been deprecated in favor of "
"response_model_exclude_unset to keep in line with Pydantic v1, support for "
"it will be removed soon."
)
def get_model_definitions(
*,
@@ -98,10 +62,7 @@ def create_response_field(
)
try:
if PYDANTIC_1:
return response_field(field_info=field_info)
else: # pragma: nocover
return response_field(schema=field_info)
return response_field(field_info=field_info)
except RuntimeError:
raise fastapi.exceptions.FastAPIError(
f"Invalid args for response field! Hint: check that {type_} is a valid pydantic field type"
@@ -137,10 +98,7 @@ def create_cloned_field(
new_field.default = field.default
new_field.required = field.required
new_field.model_config = field.model_config
if PYDANTIC_1:
new_field.field_info = field.field_info
else: # pragma: nocover
new_field.schema = field.schema # type: ignore
new_field.field_info = field.field_info
new_field.allow_none = field.allow_none
new_field.validate_always = field.validate_always
if field.sub_fields:
@@ -153,19 +111,11 @@ def create_cloned_field(
field.key_field, cloned_types=cloned_types
)
new_field.validators = field.validators
if PYDANTIC_1:
new_field.pre_validators = field.pre_validators
new_field.post_validators = field.post_validators
else: # pragma: nocover
new_field.whole_pre_validators = field.whole_pre_validators # type: ignore
new_field.whole_post_validators = field.whole_post_validators # type: ignore
new_field.pre_validators = field.pre_validators
new_field.post_validators = field.post_validators
new_field.parse_json = field.parse_json
new_field.shape = field.shape
try:
new_field.populate_validators()
except AttributeError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
new_field._populate_validators() # type: ignore
new_field.populate_validators()
return new_field

View File

@@ -33,7 +33,7 @@ classifiers = [
]
requires = [
"starlette ==0.13.6",
"pydantic >=0.32.2,<2.0.0"
"pydantic >=1.0.0,<2.0.0"
]
description-file = "README.md"
requires-python = ">=3.6"
@@ -47,6 +47,7 @@ test = [
"pytest-cov ==2.10.0",
"pytest-asyncio >=0.14.0,<0.15.0",
"mypy ==0.782",
"flake8 >=3.8.3,<4.0.0",
"black ==19.10b0",
"isort >=5.0.6,<6.0.0",
"requests >=2.24.0,<3.0.0",

View File

@@ -4,5 +4,6 @@ set -e
set -x
mypy fastapi
flake8 fastapi tests
black fastapi tests --check
isort fastapi tests docs_src scripts --check-only

View File

@@ -168,7 +168,7 @@ def get_query_type_optional(query: Optional[int] = None):
@app.get("/query/int/default")
def get_query_type_optional(query: int = 10):
def get_query_type_int_default(query: int = 10):
return f"foo bar {query}"

View File

@@ -976,8 +976,8 @@ openapi_schema = {
},
},
},
"summary": "Get Query Type Optional",
"operationId": "get_query_type_optional_query_int_default_get",
"summary": "Get Query Type Int Default",
"operationId": "get_query_type_int_default_query_int_default_get",
"parameters": [
{
"required": False,
@@ -1144,7 +1144,7 @@ def test_swagger_ui():
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert "swagger-ui-dist" in response.text
assert (
f"oauth2RedirectUrl: window.location.origin + '/docs/oauth2-redirect'"
"oauth2RedirectUrl: window.location.origin + '/docs/oauth2-redirect'"
in response.text
)

View File

@@ -59,12 +59,14 @@ async def get_callable_gen_dependency(value: str = Depends(callable_gen_dependen
@app.get("/async-callable-dependency")
async def get_callable_dependency(value: str = Depends(async_callable_dependency)):
async def get_async_callable_dependency(
value: str = Depends(async_callable_dependency),
):
return value
@app.get("/async-callable-gen-dependency")
async def get_callable_gen_dependency(
async def get_async_callable_gen_dependency(
value: str = Depends(async_callable_gen_dependency),
):
return value

View File

@@ -204,19 +204,19 @@ client = TestClient(app)
def test_async_state():
assert state["/async"] == f"asyncgen not started"
assert state["/async"] == "asyncgen not started"
response = client.get("/async")
assert response.status_code == 200, response.text
assert response.json() == f"asyncgen started"
assert state["/async"] == f"asyncgen completed"
assert response.json() == "asyncgen started"
assert state["/async"] == "asyncgen completed"
def test_sync_state():
assert state["/sync"] == f"generator not started"
assert state["/sync"] == "generator not started"
response = client.get("/sync")
assert response.status_code == 200, response.text
assert response.json() == f"generator started"
assert state["/sync"] == f"generator completed"
assert response.json() == "generator started"
assert state["/sync"] == "generator completed"
def test_async_raise_other():
@@ -281,15 +281,15 @@ def test_sync_raise():
def test_sync_async_state():
response = client.get("/sync_async")
assert response.status_code == 200, response.text
assert response.json() == f"asyncgen started"
assert state["/async"] == f"asyncgen completed"
assert response.json() == "asyncgen started"
assert state["/async"] == "asyncgen completed"
def test_sync_sync_state():
response = client.get("/sync_sync")
assert response.status_code == 200, response.text
assert response.json() == f"generator started"
assert state["/sync"] == f"generator completed"
assert response.json() == "generator started"
assert state["/sync"] == "generator completed"
def test_sync_async_raise_other():

View File

@@ -0,0 +1,39 @@
from fastapi import Depends, FastAPI
from fastapi.requests import HTTPConnection
from fastapi.testclient import TestClient
from starlette.websockets import WebSocket
app = FastAPI()
app.state.value = 42
async def extract_value_from_http_connection(conn: HTTPConnection):
return conn.app.state.value
@app.get("/http")
async def get_value_by_http(value: int = Depends(extract_value_from_http_connection)):
return value
@app.websocket("/ws")
async def get_value_by_ws(
websocket: WebSocket, value: int = Depends(extract_value_from_http_connection)
):
await websocket.accept()
await websocket.send_json(value)
await websocket.close()
client = TestClient(app)
def test_value_extracting_by_http():
response = client.get("/http")
assert response.status_code == 200
assert response.json() == 42
def test_value_extracting_by_ws():
with client.websocket_connect("/ws") as websocket:
assert websocket.receive_json() == 42

View File

@@ -5,13 +5,7 @@ from typing import Optional
import pytest
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, ValidationError, create_model
try:
from pydantic import Field
except ImportError: # pragma: nocover
# TODO: remove when removing support for Pydantic < 1.0.0
from pydantic import Schema as Field
from pydantic import BaseModel, Field, ValidationError, create_model
class Person:
@@ -134,7 +128,7 @@ def test_encode_model_with_config():
def test_encode_model_with_alias_raises():
with pytest.raises(ValidationError):
model = ModelWithAlias(foo="Bar")
ModelWithAlias(foo="Bar")
def test_encode_model_with_alias():

View File

@@ -8,6 +8,7 @@ app = FastAPI()
media_type = "application/vnd.api+json"
# NOTE: These are not valid JSON:API resources
# but they are fine for testing requestBody with custom media_type
class Product(BaseModel):

View File

@@ -54,17 +54,17 @@ def by_alias_list():
@app.get("/no-alias/dict", response_model=ModelNoAlias)
def by_alias_dict():
def no_alias_dict():
return {"name": "Foo"}
@app.get("/no-alias/model", response_model=ModelNoAlias)
def by_alias_model():
def no_alias_model():
return ModelNoAlias(name="Foo")
@app.get("/no-alias/list", response_model=List[ModelNoAlias])
def by_alias_list():
def no_alias_list():
return [{"name": "Foo"}, {"name": "Bar"}]
@@ -178,8 +178,8 @@ openapi_schema = {
},
"/no-alias/dict": {
"get": {
"summary": "By Alias Dict",
"operationId": "by_alias_dict_no_alias_dict_get",
"summary": "No Alias Dict",
"operationId": "no_alias_dict_no_alias_dict_get",
"responses": {
"200": {
"description": "Successful Response",
@@ -194,8 +194,8 @@ openapi_schema = {
},
"/no-alias/model": {
"get": {
"summary": "By Alias Model",
"operationId": "by_alias_model_no_alias_model_get",
"summary": "No Alias Model",
"operationId": "no_alias_model_no_alias_model_get",
"responses": {
"200": {
"description": "Successful Response",
@@ -210,15 +210,15 @@ openapi_schema = {
},
"/no-alias/list": {
"get": {
"summary": "By Alias List",
"operationId": "by_alias_list_no_alias_list_get",
"summary": "No Alias List",
"operationId": "no_alias_list_no_alias_list_get",
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"title": "Response By Alias List No Alias List Get",
"title": "Response No Alias List No Alias List Get",
"type": "array",
"items": {
"$ref": "#/components/schemas/ModelNoAlias"

View File

@@ -28,7 +28,7 @@ def get_current_user(oauth_header: "str" = Security(reusable_oauth2)):
@app.post("/login")
# Here we use string annotations to test them
def read_current_user(form_data: "OAuth2PasswordRequestFormStrict" = Depends()):
def login(form_data: "OAuth2PasswordRequestFormStrict" = Depends()):
return form_data
@@ -62,13 +62,13 @@ openapi_schema = {
},
},
},
"summary": "Read Current User",
"operationId": "read_current_user_login_post",
"summary": "Login",
"operationId": "login_login_post",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_read_current_user_login_post"
"$ref": "#/components/schemas/Body_login_login_post"
}
}
},
@@ -92,8 +92,8 @@ openapi_schema = {
},
"components": {
"schemas": {
"Body_read_current_user_login_post": {
"title": "Body_read_current_user_login_post",
"Body_login_login_post": {
"title": "Body_login_login_post",
"required": ["grant_type", "username", "password"],
"type": "object",
"properties": {

View File

@@ -31,12 +31,12 @@ def get_current_user(oauth_header: Optional[str] = Security(reusable_oauth2)):
@app.post("/login")
def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
def login(form_data: OAuth2PasswordRequestFormStrict = Depends()):
return form_data
@app.get("/users/me")
def read_current_user(current_user: Optional[User] = Depends(get_current_user)):
def read_users_me(current_user: Optional[User] = Depends(get_current_user)):
if current_user is None:
return {"msg": "Create an account first"}
return current_user
@@ -66,13 +66,13 @@ openapi_schema = {
},
},
},
"summary": "Read Current User",
"operationId": "read_current_user_login_post",
"summary": "Login",
"operationId": "login_login_post",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_read_current_user_login_post"
"$ref": "#/components/schemas/Body_login_login_post"
}
}
},
@@ -88,16 +88,16 @@ openapi_schema = {
"content": {"application/json": {"schema": {}}},
}
},
"summary": "Read Current User",
"operationId": "read_current_user_users_me_get",
"summary": "Read Users Me",
"operationId": "read_users_me_users_me_get",
"security": [{"OAuth2": []}],
}
},
},
"components": {
"schemas": {
"Body_read_current_user_login_post": {
"title": "Body_read_current_user_login_post",
"Body_login_login_post": {
"title": "Body_login_login_post",
"required": ["grant_type", "username", "password"],
"type": "object",
"properties": {

View File

@@ -30,14 +30,14 @@ class ModelDefaults(BaseModel):
@app.get("/", response_model=Model, response_model_exclude_unset=True)
def get() -> ModelSubclass:
def get_root() -> ModelSubclass:
return ModelSubclass(sub={}, y=1, z=0)
@app.get(
"/exclude_unset", response_model=ModelDefaults, response_model_exclude_unset=True
)
def get() -> ModelDefaults:
def get_exclude_unset() -> ModelDefaults:
return ModelDefaults(x=None, y="y")
@@ -46,14 +46,14 @@ def get() -> ModelDefaults:
response_model=ModelDefaults,
response_model_exclude_defaults=True,
)
def get() -> ModelDefaults:
def get_exclude_defaults() -> ModelDefaults:
return ModelDefaults(x=None, y="y")
@app.get(
"/exclude_none", response_model=ModelDefaults, response_model_exclude_none=True
)
def get() -> ModelDefaults:
def get_exclude_none() -> ModelDefaults:
return ModelDefaults(x=None, y="y")
@@ -63,7 +63,7 @@ def get() -> ModelDefaults:
response_model_exclude_unset=True,
response_model_exclude_none=True,
)
def get() -> ModelDefaults:
def get_exclude_unset_none() -> ModelDefaults:
return ModelDefaults(x=None, y="y")

View File

@@ -8,7 +8,7 @@ items = {"foo": "The Foo Wrestlers"}
@app.get("/items/{item_id}")
async def create_item(item_id: str):
async def read_item(item_id: str):
if item_id not in items:
raise HTTPException(
status_code=404,
@@ -19,7 +19,7 @@ async def create_item(item_id: str):
@app.get("/starlette-items/{item_id}")
async def create_item(item_id: str):
async def read_starlette_item(item_id: str):
if item_id not in items:
raise StarletteHTTPException(status_code=404, detail="Item not found")
return {"item": items[item_id]}
@@ -49,8 +49,8 @@ openapi_schema = {
},
},
},
"summary": "Create Item",
"operationId": "create_item_items__item_id__get",
"summary": "Read Item",
"operationId": "read_item_items__item_id__get",
"parameters": [
{
"required": True,
@@ -79,8 +79,8 @@ openapi_schema = {
},
},
},
"summary": "Create Item",
"operationId": "create_item_starlette_items__item_id__get",
"summary": "Read Starlette Item",
"operationId": "read_starlette_item_starlette_items__item_id__get",
"parameters": [
{
"required": True,

View File

@@ -18,9 +18,9 @@ def test_swagger_ui():
response = client.get("/docs")
assert response.status_code == 200, response.text
print(response.text)
assert f"ui.initOAuth" in response.text
assert f'"appName": "The Predendapp"' in response.text
assert f'"clientId": "the-foo-clients"' in response.text
assert "ui.initOAuth" in response.text
assert '"appName": "The Predendapp"' in response.text
assert '"clientId": "the-foo-clients"' in response.text
def test_response():

View File

@@ -126,6 +126,6 @@ def test_create_read():
assert data["text"] == note["text"]
assert data["completed"] == note["completed"]
assert "id" in data
response = client.get(f"/notes/")
response = client.get("/notes/")
assert response.status_code == 200, response.text
assert data in response.json()

View File

@@ -128,7 +128,7 @@ name_price_missing = {
body_missing = {
"detail": [
{"loc": ["body"], "msg": "field required", "type": "value_error.missing",}
{"loc": ["body"], "msg": "field required", "type": "value_error.missing"}
]
}

View File

@@ -3,15 +3,6 @@ from fastapi.testclient import TestClient
from docs_src.body_fields.tutorial001 import app
# TODO: remove when removing support for Pydantic < 1.0.0
try:
from pydantic import Field # noqa
except ImportError: # pragma: nocover
import pydantic
pydantic.Field = pydantic.Schema
client = TestClient(app)

View File

@@ -0,0 +1,22 @@
from fastapi.testclient import TestClient
from docs_src.websockets.tutorial003 import app
client = TestClient(app)
def test_websocket_handle_disconnection():
with client.websocket_connect("/ws/1234") as connection, client.websocket_connect(
"/ws/5678"
) as connection_two:
connection.send_text("Hello from 1234")
data1 = connection.receive_text()
assert data1 == "You wrote: Hello from 1234"
data2 = connection_two.receive_text()
client1_says = "Client #1234 says: Hello from 1234"
assert data2 == client1_says
data1 = connection.receive_text()
assert data1 == client1_says
connection_two.close()
data1 = connection.receive_text()
assert data1 == "Client #5678 left the chat"

View File

@@ -28,7 +28,7 @@ async def routerprefixindex(websocket: WebSocket):
@router.websocket("/router2")
async def routerindex(websocket: WebSocket):
async def routerindex2(websocket: WebSocket):
await websocket.accept()
await websocket.send_text("Hello, router!")
await websocket.close()