Compare commits

...

18 Commits

Author SHA1 Message Date
Sebastián Ramírez
c8eea09664 📝 Update release notes 2019-06-05 21:20:12 +04:00
Sebastián Ramírez
5700d65188 🔖 Release 0.28.0 2019-06-05 21:13:32 +04:00
Sebastián Ramírez
46178a5347 📝 Update release notes 2019-06-05 21:09:11 +04:00
Sebastián Ramírez
bff5dbbf5d Implement dependency value cache per request (#292)
*  Add dependency cache, with support for disabling it

*  Add tests for dependency cache

* 📝 Add docs about dependency value caching
2019-06-05 21:00:54 +04:00
Sebastián Ramírez
09cd7c47a1 Implement dependency overrides for testing (#291)
*  Implement dependency overrides for testing

*  Add docs source tests and extra tests for dependency overrides

* 📝 Add docs for testing dependencies with overrides
2019-06-05 15:43:18 +04:00
Sebastián Ramírez
e2fadcbc90 🔖 Release version 0.27.2 2019-06-03 22:03:24 +04:00
Sebastián Ramírez
b3bb29afa8 📝 Update relase notes 2019-06-03 22:01:09 +04:00
Sebastián Ramírez
c7db2ff858 🐛 Fix path and query parameters receiving dict as valid (#287)
* 🐛 Fix path and query parameters accepting dict

*  Add several tests to ensure invalid types are not accepted

* 📝 Document (to include tested source) using query params with list

* 🐛 Fix OpenAPI schema in query with list tutorial
2019-06-03 21:59:40 +04:00
Sebastián Ramírez
2a7ef5504a 🔖 Release 0.27.1 2019-06-03 18:44:03 +04:00
Sebastián Ramírez
27964c5ffd 📝 Update release notes 2019-06-01 10:00:26 +04:00
Sebastián Ramírez
d262f6e929 🐛 Fix HTTP Bearer security auto-error (#282) 2019-06-01 09:57:45 +04:00
Sebastián Ramírez
d61f5e4b55 📝 Update release notes 2019-05-30 19:43:32 +04:00
Sebastián Ramírez
3ed112e8a9 🐛 Fix type declaration of HTTPException (#279) 2019-05-30 19:43:02 +04:00
Sebastián Ramírez
9da626eb2c 🔖 Release version 0.27.0 2019-05-30 17:48:52 +04:00
Sebastián Ramírez
6f74c7327b 📝 Update release notes 2019-05-30 17:45:38 +04:00
dmontagu
360a2797c1 🐛 Fix docs link in oauth2-scopes.md (#275)
#274
2019-05-30 17:43:18 +04:00
Sebastián Ramírez
0552977cd6 📝 Update release notes 2019-05-30 17:41:40 +04:00
Sebastián Ramírez
bd407cc4ed Refactor param extraction using Pydantic Field (#278)
*  Refactor parameter dependency using Pydantic Field

* ⬆️ Upgrade required Pydantic version with latest Shape values

*  Add tutorials and code for using Enum and Optional

*  Add tests for tutorials with new types and extra cases

* ♻️ Format, clean, and add annotations to dependencies.utils

* 📝 Update tutorial for query parameters with list defaults

*  Add tests for query param with list default
2019-05-30 17:40:43 +04:00
38 changed files with 1595 additions and 167 deletions

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

View File

@@ -1,4 +1,40 @@
## Next release
## Latest changes
## 0.28.0
* Implement dependency cache per request.
* This avoids calling each dependency multiple times for the same request.
* This is useful while calling external services, performing costly computation, etc.
* This also means that if a dependency was declared as a *path operation decorator* dependency, possibly at the router level (with `.include_router()`) and then it is declared again in a specific *path operation*, the dependency will be called only once.
* The cache can be disabled per dependency declaration, using `use_cache=False` as in `Depends(your_dependency, use_cache=False)`.
* Updated docs at: [Using the same dependency multiple times](https://fastapi.tiangolo.com/tutorial/dependencies/sub-dependencies/#using-the-same-dependency-multiple-times).
* PR [#292](https://github.com/tiangolo/fastapi/pull/292).
* Implement dependency overrides for testing.
* This allows using overrides/mocks of dependencies during tests.
* New docs: [Testing Dependencies with Overrides](https://fastapi.tiangolo.com/tutorial/testing-dependencies/).
* PR [#291](https://github.com/tiangolo/fastapi/pull/291).
## 0.27.2
* Fix path and query parameters receiving `dict` as a valid type. It should be mapped to a body payload. PR [#287](https://github.com/tiangolo/fastapi/pull/287). Updated docs at: [Query parameter list / multiple values with defaults: Using `list`](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#using-list).
## 0.27.1
* Fix `auto_error=False` handling in `HTTPBearer` security scheme. Do not `raise` when there's an incorrect `Authorization` header if `auto_error=False`. PR [#282](https://github.com/tiangolo/fastapi/pull/282).
* Fix type declaration of `HTTPException`. PR [#279](https://github.com/tiangolo/fastapi/pull/279).
## 0.27.0
* Fix broken link in docs about OAuth 2.0 with scopes. PR [#275](https://github.com/tiangolo/fastapi/pull/275) by [@dmontagu](https://github.com/dmontagu).
* Refactor param extraction using Pydantic `Field`:
* Large refactor, improvement, and simplification of param extraction from *path operations*.
* Fix/add support for list *query parameters* with list defaults. New documentation: [Query parameter list / multiple values with defaults](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#query-parameter-list-multiple-values-with-defaults).
* Add support for enumerations in *path operation* parameters. New documentation: [Path Parameters: Predefined values](https://fastapi.tiangolo.com/tutorial/path-params/#predefined-values).
* Add support for type annotations using `Optional` as in `param: Optional[str] = None`. New documentation: [Optional type declarations](https://fastapi.tiangolo.com/tutorial/query-params/#optional-type-declarations).
* PR [#278](https://github.com/tiangolo/fastapi/pull/278).
## 0.26.0

View File

@@ -0,0 +1,55 @@
from fastapi import Depends, FastAPI
from starlette.testclient import TestClient
app = FastAPI()
async def common_parameters(q: str = None, skip: int = 0, limit: int = 100):
return {"q": q, "skip": skip, "limit": limit}
@app.get("/items/")
async def read_items(commons: dict = Depends(common_parameters)):
return {"message": "Hello Items!", "params": commons}
@app.get("/users/")
async def read_users(commons: dict = Depends(common_parameters)):
return {"message": "Hello Users!", "params": commons}
client = TestClient(app)
async def override_dependency(q: str = None):
return {"q": q, "skip": 5, "limit": 10}
app.dependency_overrides[common_parameters] = override_dependency
def test_override_in_items():
response = client.get("/items/")
assert response.status_code == 200
assert response.json() == {
"message": "Hello Items!",
"params": {"q": None, "skip": 5, "limit": 10},
}
def test_override_in_items_with_q():
response = client.get("/items/?q=foo")
assert response.status_code == 200
assert response.json() == {
"message": "Hello Items!",
"params": {"q": "foo", "skip": 5, "limit": 10},
}
def test_override_in_items_with_params():
response = client.get("/items/?q=foo&skip=100&limit=200")
assert response.status_code == 200
assert response.json() == {
"message": "Hello Items!",
"params": {"q": "foo", "skip": 5, "limit": 10},
}

View File

@@ -0,0 +1,21 @@
from enum import Enum
from fastapi import FastAPI
class ModelName(Enum):
alexnet = "alexnet"
resnet = "resnet"
lenet = "lenet"
app = FastAPI()
@app.get("/model/{model_name}")
async def get_model(model_name: ModelName):
if model_name == ModelName.alexnet:
return {"model_name": model_name, "message": "Deep Learning FTW!"}
if model_name.value == "lenet":
return {"model_name": model_name, "message": "LeCNN all the images"}
return {"model_name": model_name, "message": "Have some residuals"}

View File

@@ -0,0 +1,11 @@
from typing import Optional
from fastapi import FastAPI
app = FastAPI()
@app.get("/items/{item_id}")
async def read_user_item(item_id: str, limit: Optional[int] = None):
item = {"item_id": item_id, "limit": limit}
return item

View File

@@ -0,0 +1,11 @@
from typing import List
from fastapi import FastAPI, Query
app = FastAPI()
@app.get("/items/")
async def read_items(q: List[str] = Query(["foo", "bar"])):
query_items = {"q": q}
return query_items

View File

@@ -0,0 +1,9 @@
from fastapi import FastAPI, Query
app = FastAPI()
@app.get("/items/")
async def read_items(q: list = Query(None)):
query_items = {"q": q}
return query_items

View File

@@ -17,14 +17,12 @@ This is very useful when you need to:
All these, while minimizing code repetition.
## First Steps
Let's see a very simple example. It will be so simple that it is not very useful, for now.
But this way we can focus on how the **Dependency Injection** system works.
### Create a dependency, or "dependable"
Let's first focus on the dependency.
@@ -151,7 +149,6 @@ The simplicity of the dependency injection system makes **FastAPI** compatible w
* response data injection systems
* etc.
## Simple and Powerful
Although the hierarchical dependency injection system is very simple to define and use, it's still very powerful.

View File

@@ -11,6 +11,7 @@ You could create a first dependency ("dependable") like:
```Python hl_lines="6 7"
{!./src/dependencies/tutorial005.py!}
```
It declares an optional query parameter `q` as a `str`, and then it just returns it.
This is quite simple (not very useful), but will help us focus on how the sub-dependencies work.
@@ -43,6 +44,18 @@ Then we can use the dependency with:
But **FastAPI** will know that it has to solve `query_extractor` first, to pass the results of that to `query_or_cookie_extractor` while calling it.
## Using the same dependency multiple times
If one of your dependencies is declared multiple times for the same *path operation*, for example, multiple dependencies have a common sub-dependency, **FastAPI** will know to call that sub-dependency only once per request.
And it will save the returned value in a <abbr title="A utility/system to store computed/generated values, to re-use them instead of computing them again.">"cache"</abbr> and pass it to all the "dependants" that need it in that specific request, instead of calling the dependency multiple times for the same request.
In an advanced scenario where you know you need the dependency to be called at every step (possibly multiple times) in the same request instead of using the "cached" value, you can set the parameter `use_cache=False` when using `Depends`:
```Python hl_lines="1"
async def needy_dependency(fresh_value: str = Depends(get_value, use_cache=False)):
return {"fresh_value": fresh_value}
```
## Recap
@@ -54,7 +67,7 @@ But still, it is very powerful, and allows you to declare arbitrarily deeply nes
!!! tip
All this might not seem as useful with these simple examples.
But you will see how useful it is in the chapters about **security**.
And you will also see the amounts of code it will save you.

View File

@@ -35,7 +35,7 @@ If you run this example and open your browser at <a href="http://127.0.0.1:8000/
!!! check
Notice that the value your function received (and returned) is `3`, as a Python `int`, not a string `"3"`.
So, with that type declaration, **FastAPI** gives you automatic request <abbr title="converting the string that comes from an HTTP request into Python data">"parsing"</abbr>.
## Data validation
@@ -61,12 +61,11 @@ because the path parameter `item_id` had a value of `"foo"`, which is not an `in
The same error would appear if you provided a `float` instead of an int, as in: <a href="http://127.0.0.1:8000/items/4.2" target="_blank">http://127.0.0.1:8000/items/4.2</a>
!!! check
So, with the same Python type declaration, **FastAPI** gives you data validation.
Notice that the error also clearly states exactly the point where the validation didn't pass.
Notice that the error also clearly states exactly the point where the validation didn't pass.
This is incredibly helpful while developing and debugging code that interacts with your API.
## Documentation
@@ -96,8 +95,7 @@ All the data validation is performed under the hood by <a href="https://pydantic
You can use the same type declarations with `str`, `float`, `bool` and many other complex data types.
These are explored in the next chapters of the tutorial.
Several of these are explored in the next chapters of the tutorial.
## Order matters
@@ -115,6 +113,73 @@ Because path operations are evaluated in order, you need to make sure that the p
Otherwise, the path for `/users/{user_id}` would match also for `/users/me`, "thinking" that it's receiving a parameter `user_id` with a value of `"me"`.
## Predefined values
If you have a *path operation* that receives a *path parameter*, but you want the possible valid *path parameter* values to be predefined, you can use a standard Python <abbr title="Enumeration">`Enum`</abbr>.
### Create an `Enum` class
Import `Enum` and create a sub-class that inherits from it.
And create class attributes with fixed values, those fixed values will be the available valid values:
```Python hl_lines="1 6 7 8 9"
{!./src/path_params/tutorial005.py!}
```
!!! info
<a href="https://docs.python.org/3/library/enum.html" target="_blank">Enumerations (or enums) are available in Python</a> since version 3.4.
!!! tip
If you are wondering, "AlexNet", "ResNet", and "LeNet" are just names of Machine Learning <abbr title="Technically, Deep Learning model architectures">models</abbr>.
### Declare a *path parameter*
Then create a *path parameter* with a type annotation using the enum class you created (`ModelName`):
```Python hl_lines="16"
{!./src/path_params/tutorial005.py!}
```
### Check the docs
Because the available values for the *path parameter* are specified, the interactive docs can show them nicely:
<img src="/img/tutorial/path-params/image03.png">
### Working with Python *enumerations*
The value of the *path parameter* will be an *enumeration member*.
#### Compare *enumeration members*
You can compare it with the *enumeration member* in your created enum `ModelName`:
```Python hl_lines="17"
{!./src/path_params/tutorial005.py!}
```
#### Get the *enumeration value*
You can get the actual value (a `str` in this case) using `model_name.value`, or in general, `your_enum_member.value`:
```Python hl_lines="19"
{!./src/path_params/tutorial005.py!}
```
!!! tip
You could also access the value `"lenet"` with `ModelName.lenet.value`.
#### Return *enumeration members*
You can return *enum members* from your *path operation*, even nested in a JSON body (e.g. a `dict`).
They will be converted to their corresponding values before returning them to the client:
```Python hl_lines="18 20 21"
{!./src/path_params/tutorial005.py!}
```
## Path parameters containing paths
Let's say you have a *path operation* with a path `/files/{file_path}`.

View File

@@ -12,7 +12,6 @@ The query parameter `q` is of type `str`, and by default is `None`, so it is opt
We are going to enforce that even though `q` is optional, whenever it is provided, it **doesn't exceed a length of 50 characters**.
### Import `Query`
To achieve that, first import `Query` from `fastapi`:
@@ -29,7 +28,7 @@ And now use it as the default value of your parameter, setting the parameter `ma
{!./src/query_params_str_validations/tutorial002.py!}
```
As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value.
As we have to replace the default value `None` with `Query(None)`, the first parameter to `Query` serves the same purpose of defining that default value.
So:
@@ -41,7 +40,7 @@ q: str = Query(None)
```Python
q: str = None
```
```
But it declares it explicitly as being a query parameter.
@@ -53,7 +52,6 @@ q: str = Query(None, max_length=50)
This will validate the data, show a clear error when the data is not valid, and document the parameter in the OpenAPI schema path operation.
## Add more validations
You can also add a parameter `min_length`:
@@ -119,7 +117,7 @@ So, when you need to declare a value as required while using `Query`, you can us
{!./src/query_params_str_validations/tutorial006.py!}
```
!!! info
!!! info
If you hadn't seen that `...` before: it is a a special single value, it is <a href="https://docs.python.org/3/library/constants.html#Ellipsis" target="_blank">part of Python and is called "Ellipsis"</a>.
This will let **FastAPI** know that this parameter is required.
@@ -156,11 +154,48 @@ So, the response to that URL would be:
!!! tip
To declare a query parameter with a type of `list`, like in the example above, you need to explicitly use `Query`, otherwise it would be interpreted as a request body.
The interactive API docs will update accordingly, to allow multiple values:
<img src="/img/tutorial/query-params-str-validations/image02.png">
### Query parameter list / multiple values with defaults
And you can also define a default `list` of values if none are provided:
```Python hl_lines="9"
{!./src/query_params_str_validations/tutorial012.py!}
```
If you go to:
```
http://localhost:8000/items/
```
the default of `q` will be: `["foo", "bar"]` and your response will be:
```JSON
{
"q": [
"foo",
"bar"
]
}
```
#### Using `list`
You can also use `list` directly instead of `List[str]`:
```Python hl_lines="7"
{!./src/query_params_str_validations/tutorial013.py!}
```
!!! note
Have in mind that in this case, FastAPI won't check the contents of the list.
For example, `List[int]` would check (and document) that the contents of the list are integers. But `list` alone wouldn't.
## Declare more metadata
You can add more information about the parameter.

View File

@@ -186,3 +186,39 @@ In this case, there are 3 query parameters:
* `needy`, a required `str`.
* `skip`, an `int` with a default value of `0`.
* `limit`, an optional `int`.
!!! tip
You could also use `Enum`s <a href="https://fastapi.tiangolo.com/tutorial/path-params/#predefined-values" target="_blank">the same way as with *path parameters*</a>.
## Optional type declarations
!!! warning
This might be an advanced use case.
You might want to skip it.
If you are using `mypy` it could complain with type declarations like:
```Python
limit: int = None
```
With an error like:
```
Incompatible types in assignment (expression has type "None", variable has type "int")
```
In those cases you can use `Optional` to tell `mypy` that the value could be `None`, like:
```Python
from typing import Optional
limit: Optional[int] = None
```
In a *path operation* that could look like:
```Python hl_lines="9"
{!./src/query_params/tutorial007.py!}
```

View File

@@ -247,4 +247,4 @@ The most secure is the code flow, but is more complex to implement as it require
## `Security` in decorator `dependencies`
The same way you can define a `list` of <a href="https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-decorator/" target="_blank">`Depends` in the decorator's `dependencies` parameter</a>, you could also use `Security` with `scopes` there.
The same way you can define a `list` of <a href="https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/" target="_blank">`Depends` in the decorator's `dependencies` parameter</a>, you could also use `Security` with `scopes` there.

View File

@@ -0,0 +1,59 @@
## Overriding dependencies during testing
There are some scenarios where you might want to override a dependency during testing.
You don't want the original dependency to run (nor any of the sub-dependencies it might have).
Instead, you want to provide a different dependency that will be used only during tests (possibly only some specific tests), and will provide a value that can be used where the value of the original dependency was used.
### Use cases: external service
An example could be that you have an external authentication provider that you need to call.
You send it a token and it returns an authenticated user.
This provider might be charging you per request, and calling it might take some extra time than if you had a fixed mock user for tests.
You probably want to test the external provider once, but not necessarily call it for every test that runs.
In this case, you can override the dependency that calls that provider, and use a custom dependency that returns a mock user, only for your tests.
### Use case: testing database
Other example could be that you are using a specific database only for testing.
Your normal dependency would return a database session.
But then, after each test, you could want to rollback all the operations or remove data.
Or you could want to alter the data before the tests run, etc.
In this case, you could use a dependency override to return your *custom* database session instead of the one that would be used normally.
### Use the `app.dependency_overrides` attribute
For these cases, your **FastAPI** application has an attribute `app.dependency_overrides`, it is a simple `dict`.
To override a dependency for testing, you put as a key the original dependency (a function), and as the value, your dependency override (another function).
And then **FastAPI** will call that override instead of the original dependency.
```Python hl_lines="24 25 28"
{!./src/dependency_testing/tutorial001.py!}
```
!!! tip
You can set a dependency override for a dependency used anywhere in your **FastAPI** application.
The original dependency could be used in a *path operation function*, a *path operation decorator* (when you don't use the return value), a `.include_router()` call, etc.
FastAPI will still be able to override it.
Then you can reset your overrides (remove them) by setting `app.dependency_overrides` to be an empty `dict`:
```Python
app.dependency_overrides = {}
```
!!! tip
If you want to override a dependency only during some tests, you can set the override at the beginning of the test (inside the test function) and reset it at the end (at the end of the test function).

View File

@@ -1,6 +1,6 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
__version__ = "0.26.0"
__version__ = "0.28.0"
from starlette.background import BackgroundTasks

View File

@@ -38,7 +38,9 @@ class FastAPI(Starlette):
**extra: Dict[str, Any],
) -> None:
self._debug = debug
self.router: routing.APIRouter = routing.APIRouter(routes)
self.router: routing.APIRouter = routing.APIRouter(
routes, dependency_overrides_provider=self
)
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
@@ -53,6 +55,7 @@ class FastAPI(Starlette):
self.redoc_url = redoc_url
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
self.extra = extra
self.dependency_overrides: Dict[Callable, Callable] = {}
self.openapi_version = "3.0.2"

View File

@@ -30,6 +30,8 @@ class Dependant:
background_tasks_param_name: str = None,
security_scopes_param_name: str = None,
security_scopes: List[str] = None,
use_cache: bool = True,
path: str = None,
) -> None:
self.path_params = path_params or []
self.query_params = query_params or []
@@ -45,3 +47,8 @@ class Dependant:
self.security_scopes_param_name = security_scopes_param_name
self.name = name
self.call = call
self.use_cache = use_cache
# Store the path to be able to re-generate a dependable from it in overrides
self.path = path
# Save the cache key at creation to optimize performance
self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or []))))

View File

@@ -1,8 +1,6 @@
import asyncio
import inspect
from copy import deepcopy
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import (
Any,
Callable,
@@ -14,8 +12,8 @@ from typing import (
Tuple,
Type,
Union,
cast,
)
from uuid import UUID
from fastapi import params
from fastapi.dependencies.models import Dependant, SecurityRequirement
@@ -23,7 +21,7 @@ from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import get_path_param_names
from pydantic import BaseConfig, Schema, create_model
from pydantic import BaseConfig, BaseModel, Schema, create_model
from pydantic.error_wrappers import ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import Field, Required, Shape
@@ -35,22 +33,21 @@ from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import Request
from starlette.websockets import WebSocket
param_supported_types = (
str,
int,
float,
bool,
UUID,
date,
datetime,
time,
timedelta,
Decimal,
)
sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE}
sequence_shapes = {
Shape.LIST,
Shape.SET,
Shape.TUPLE,
Shape.SEQUENCE,
Shape.TUPLE_ELLIPS,
}
sequence_types = (list, set, tuple)
sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple}
sequence_shape_to_type = {
Shape.LIST: list,
Shape.SET: set,
Shape.TUPLE: tuple,
Shape.SEQUENCE: list,
Shape.TUPLE_ELLIPS: list,
}
def get_param_sub_dependant(
@@ -98,7 +95,11 @@ def get_sub_dependant(
security_scheme=dependency, scopes=use_scopes
)
sub_dependant = get_dependant(
path=path, call=dependency, name=name, security_scopes=security_scopes
path=path,
call=dependency,
name=name,
security_scopes=security_scopes,
use_cache=depends.use_cache,
)
if security_requirement:
sub_dependant.security_requirements.append(security_requirement)
@@ -114,6 +115,8 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_schemes=dependant.security_requirements.copy(),
use_cache=dependant.use_cache,
path=dependant.path,
)
for sub_dependant in dependant.dependencies:
flat_sub = get_flat_dependant(sub_dependant)
@@ -126,90 +129,113 @@ def get_flat_dependant(dependant: Dependant) -> Dependant:
return flat_dependant
def is_scalar_field(field: Field) -> bool:
return (
field.shape == Shape.SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, sequence_types + (dict,))
and not isinstance(field.schema, params.Body)
)
def is_scalar_sequence_field(field: Field) -> bool:
if (field.shape in sequence_shapes) and not lenient_issubclass(
field.type_, BaseModel
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_scalar_field(sub_field):
return False
return True
if lenient_issubclass(field.type_, sequence_types):
return True
return False
def get_dependant(
*, path: str, call: Callable, name: str = None, security_scopes: List[str] = None
*,
path: str,
call: Callable,
name: str = None,
security_scopes: List[str] = None,
use_cache: bool = True,
) -> Dependant:
path_param_names = get_path_param_names(path)
endpoint_signature = inspect.signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(call=call, name=name)
for param_name in signature_params:
param = signature_params[param_name]
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
for param_name, param in signature_params.items():
if isinstance(param.default, params.Depends):
sub_dependant = get_param_sub_dependant(
param=param, path=path, security_scopes=security_scopes
)
dependant.dependencies.append(sub_dependant)
for param_name in signature_params:
param = signature_params[param_name]
if (
(param.default == param.empty) or isinstance(param.default, params.Path)
) and (param_name in path_param_names):
assert (
lenient_issubclass(param.annotation, param_supported_types)
or param.annotation == param.empty
for param_name, param in signature_params.items():
if isinstance(param.default, params.Depends):
continue
if add_non_field_param_to_dependency(param=param, dependant=dependant):
continue
param_field = get_param_field(param=param, default_schema=params.Query)
if param_name in path_param_names:
assert param.default == param.empty or isinstance(
param.default, params.Path
), "Path params must have no defaults or use Path(...)"
assert is_scalar_field(
field=param_field
), f"Path params must be of one of the supported types"
add_param_to_fields(
param_field = get_param_field(
param=param,
dependant=dependant,
default_schema=params.Path,
force_type=params.ParamTypes.path,
)
elif (
param.default == param.empty
or param.default is None
or isinstance(param.default, param_supported_types)
) and (
param.annotation == param.empty
or lenient_issubclass(param.annotation, param_supported_types)
):
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif isinstance(param.default, params.Param):
if param.annotation != param.empty:
origin = getattr(param.annotation, "__origin__", None)
param_all_types = param_supported_types + (list, tuple, set)
if isinstance(param.default, (params.Query, params.Header)):
assert lenient_issubclass(
param.annotation, param_all_types
) or lenient_issubclass(
origin, param_all_types
), f"Parameters for Query and Header must be of type str, int, float, bool, list, tuple or set: {param}"
else:
assert lenient_issubclass(
param.annotation, param_supported_types
), f"Parameters for Path and Cookies must be of type str, int, float, bool: {param}"
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param_name
elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param_name
elif lenient_issubclass(param.annotation, SecurityScopes):
dependant.security_scopes_param_name = param_name
elif not isinstance(param.default, params.Depends):
add_param_to_body_fields(param=param, dependant=dependant)
add_param_to_fields(field=param_field, dependant=dependant)
elif is_scalar_field(field=param_field):
add_param_to_fields(field=param_field, dependant=dependant)
elif isinstance(
param.default, (params.Query, params.Header)
) and is_scalar_sequence_field(param_field):
add_param_to_fields(field=param_field, dependant=dependant)
else:
assert isinstance(
param_field.schema, params.Body
), f"Param: {param_field.name} can only be a request body, using Body(...)"
dependant.body_params.append(param_field)
return dependant
def add_param_to_fields(
def add_non_field_param_to_dependency(
*, param: inspect.Parameter, dependant: Dependant
) -> Optional[bool]:
if lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param.name
return True
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param.name
return True
elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param.name
return True
elif lenient_issubclass(param.annotation, SecurityScopes):
dependant.security_scopes_param_name = param.name
return True
return None
def get_param_field(
*,
param: inspect.Parameter,
dependant: Dependant,
default_schema: Type[Schema] = params.Param,
default_schema: Type[params.Param] = params.Param,
force_type: params.ParamTypes = None,
) -> None:
) -> Field:
default_value = Required
had_schema = False
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, params.Param):
if isinstance(default_value, Schema):
had_schema = True
schema = default_value
default_value = schema.default
if getattr(schema, "in_", None) is None:
if isinstance(schema, params.Param) and getattr(schema, "in_", None) is None:
schema.in_ = default_schema.in_
if force_type:
schema.in_ = force_type
@@ -234,43 +260,26 @@ def add_param_to_fields(
class_validators={},
schema=schema,
)
if schema.in_ == params.ParamTypes.path:
if not had_schema and not is_scalar_field(field=field):
field.schema = params.Body(schema.default)
return field
def add_param_to_fields(*, field: Field, dependant: Dependant) -> None:
field.schema = cast(params.Param, field.schema)
if field.schema.in_ == params.ParamTypes.path:
dependant.path_params.append(field)
elif schema.in_ == params.ParamTypes.query:
elif field.schema.in_ == params.ParamTypes.query:
dependant.query_params.append(field)
elif schema.in_ == params.ParamTypes.header:
elif field.schema.in_ == params.ParamTypes.header:
dependant.header_params.append(field)
else:
assert (
schema.in_ == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {param.name}"
field.schema.in_ == params.ParamTypes.cookie
), f"non-body parameters must be in path, query, header or cookie: {field.name}"
dependant.cookie_params.append(field)
def add_param_to_body_fields(*, param: inspect.Parameter, dependant: Dependant) -> None:
default_value = Required
if not param.default == param.empty:
default_value = param.default
if isinstance(default_value, Schema):
schema = default_value
default_value = schema.default
else:
schema = Schema(default_value)
required = default_value == Required
annotation = get_annotation_from_schema(param.annotation, schema)
field = Field(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=schema.alias or param.name,
required=required,
model_config=BaseConfig,
class_validators={},
schema=schema,
)
dependant.body_params.append(field)
def is_coroutine_callable(call: Callable) -> bool:
if inspect.isfunction(call):
return asyncio.iscoroutinefunction(call)
@@ -286,26 +295,63 @@ async def solve_dependencies(
dependant: Dependant,
body: Dict[str, Any] = None,
background_tasks: BackgroundTasks = None,
) -> Tuple[Dict[str, Any], List[ErrorWrapper], Optional[BackgroundTasks]]:
dependency_overrides_provider: Any = None,
dependency_cache: Dict[Tuple[Callable, Tuple[str]], Any] = None,
) -> Tuple[
Dict[str, Any],
List[ErrorWrapper],
Optional[BackgroundTasks],
Dict[Tuple[Callable, Tuple[str]], Any],
]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
sub_values, sub_errors, background_tasks = await solve_dependencies(
sub_dependant.call = cast(Callable, sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable, Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
path=use_path,
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
)
sub_values, sub_errors, background_tasks, sub_dependency_cache = await solve_dependencies(
request=request,
dependant=sub_dependant,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
)
dependency_cache.update(sub_dependency_cache)
if sub_errors:
errors.extend(sub_errors)
continue
assert sub_dependant.call is not None, "sub_dependant.call must be a function"
if is_coroutine_callable(sub_dependant.call):
solved = await sub_dependant.call(**sub_values)
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_coroutine_callable(call):
solved = await call(**sub_values)
else:
solved = await run_in_threadpool(sub_dependant.call, **sub_values)
solved = await run_in_threadpool(call, **sub_values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
@@ -341,7 +387,7 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return values, errors, background_tasks
return values, errors, background_tasks, dependency_cache
def request_params_to_args(
@@ -351,10 +397,10 @@ def request_params_to_args(
values = {}
errors = []
for field in required_params:
if field.shape in sequence_shapes and isinstance(
if is_scalar_sequence_field(field) and isinstance(
received_params, (QueryParams, Headers)
):
value = received_params.getlist(field.alias)
value = received_params.getlist(field.alias) or field.default
else:
value = received_params.get(field.alias)
schema: params.Param = field.schema

View File

@@ -1,10 +1,12 @@
from typing import Any
from pydantic import ValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
class HTTPException(StarletteHTTPException):
def __init__(
self, status_code: int, detail: str = None, headers: dict = None
self, status_code: int, detail: Any = None, headers: dict = None
) -> None:
super().__init__(status_code=status_code, detail=detail)
self.headers = headers

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, cast
from fastapi import routing
from fastapi.dependencies.models import Dependant
@@ -9,7 +9,7 @@ from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.utils import get_flat_models_from_routes, get_model_definitions
from pydantic.fields import Field
from pydantic.schema import Schema, field_schema, get_model_name_map
from pydantic.schema import field_schema, get_model_name_map
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
@@ -97,12 +97,8 @@ def get_openapi_operation_request_body(
body_schema, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
schema: Schema = body_field.schema
if isinstance(schema, Body):
request_media_type = schema.media_type
else:
# Includes not declared media types (Schema)
request_media_type = "application/json"
body_field.schema = cast(Body, body_field.schema)
request_media_type = body_field.schema.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}
if required:

View File

@@ -238,11 +238,13 @@ def File( # noqa: N802
)
def Depends(dependency: Callable = None) -> Any: # noqa: N802
return params.Depends(dependency=dependency)
def Depends( # noqa: N802
dependency: Callable = None, *, use_cache: bool = True
) -> Any:
return params.Depends(dependency=dependency, use_cache=use_cache)
def Security( # noqa: N802
dependency: Callable = None, scopes: Sequence[str] = None
dependency: Callable = None, *, scopes: Sequence[str] = None, use_cache: bool = True
) -> Any:
return params.Security(dependency=dependency, scopes=scopes)
return params.Security(dependency=dependency, scopes=scopes, use_cache=use_cache)

View File

@@ -308,11 +308,18 @@ class File(Form):
class Depends:
def __init__(self, dependency: Callable = None):
def __init__(self, dependency: Callable = None, *, use_cache: bool = True):
self.dependency = dependency
self.use_cache = use_cache
class Security(Depends):
def __init__(self, dependency: Callable = None, scopes: Sequence[str] = None):
def __init__(
self,
dependency: Callable = None,
*,
scopes: Sequence[str] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or []
super().__init__(dependency=dependency)

View File

@@ -30,6 +30,7 @@ from starlette.routing import (
websocket_session,
)
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.types import ASGIApp
from starlette.websockets import WebSocket
@@ -80,6 +81,7 @@ def get_app(
response_model_exclude: Set[str] = set(),
response_model_by_alias: bool = True,
response_model_skip_defaults: bool = False,
dependency_overrides_provider: Any = None,
) -> Callable:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = asyncio.iscoroutinefunction(dependant.call)
@@ -100,8 +102,11 @@ def get_app(
raise HTTPException(
status_code=400, detail="There was an error parsing the body"
) from e
values, errors, background_tasks = await solve_dependencies(
request=request, dependant=dependant, body=body
values, errors, background_tasks, _ = await solve_dependencies(
request=request,
dependant=dependant,
body=body,
dependency_overrides_provider=dependency_overrides_provider,
)
if errors:
raise RequestValidationError(errors)
@@ -132,10 +137,14 @@ def get_app(
return app
def get_websocket_app(dependant: Dependant) -> Callable:
def get_websocket_app(
dependant: Dependant, dependency_overrides_provider: Any = None
) -> Callable:
async def app(websocket: WebSocket) -> None:
values, errors, _ = await solve_dependencies(
request=websocket, dependant=dependant
values, errors, _, _2 = await solve_dependencies(
request=websocket,
dependant=dependant,
dependency_overrides_provider=dependency_overrides_provider,
)
if errors:
await websocket.close(code=WS_1008_POLICY_VIOLATION)
@@ -147,12 +156,24 @@ def get_websocket_app(dependant: Dependant) -> Callable:
class APIWebSocketRoute(routing.WebSocketRoute):
def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
def __init__(
self,
path: str,
endpoint: Callable,
*,
name: str = None,
dependency_overrides_provider: Any = None,
) -> None:
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.dependant = get_dependant(path=path, call=self.endpoint)
self.app = websocket_session(get_websocket_app(dependant=self.dependant))
self.app = websocket_session(
get_websocket_app(
dependant=self.dependant,
dependency_overrides_provider=dependency_overrides_provider,
)
)
regex = "^" + path + "$"
regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
@@ -182,6 +203,7 @@ class APIRoute(routing.Route):
response_model_skip_defaults: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = JSONResponse,
dependency_overrides_provider: Any = None,
) -> None:
assert path.startswith("/"), "Routed paths must always start with '/'"
self.path = path
@@ -257,6 +279,7 @@ class APIRoute(routing.Route):
get_parameterless_sub_dependant(depends=depends, path=self.path_format),
)
self.body_field = get_body_field(dependant=self.dependant, name=self.name)
self.dependency_overrides_provider = dependency_overrides_provider
self.app = request_response(
get_app(
dependant=self.dependant,
@@ -268,11 +291,24 @@ class APIRoute(routing.Route):
response_model_exclude=self.response_model_exclude,
response_model_by_alias=self.response_model_by_alias,
response_model_skip_defaults=self.response_model_skip_defaults,
dependency_overrides_provider=self.dependency_overrides_provider,
)
)
class APIRouter(routing.Router):
def __init__(
self,
routes: List[routing.BaseRoute] = None,
redirect_slashes: bool = True,
default: ASGIApp = None,
dependency_overrides_provider: Any = None,
) -> None:
super().__init__(
routes=routes, redirect_slashes=redirect_slashes, default=default
)
self.dependency_overrides_provider = dependency_overrides_provider
def add_api_route(
self,
path: str,
@@ -318,6 +354,7 @@ class APIRouter(routing.Router):
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
dependency_overrides_provider=self.dependency_overrides_provider,
)
self.routes.append(route)

View File

@@ -112,10 +112,13 @@ class HTTPBearer(HTTPBase):
else:
return None
if scheme.lower() != "bearer":
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

View File

@@ -81,6 +81,7 @@ nav:
- WebSockets: 'tutorial/websockets.md'
- 'Events: startup - shutdown': 'tutorial/events.md'
- Testing: 'tutorial/testing.md'
- Testing Dependencies with Overrides: 'tutorial/testing-dependencies.md'
- Debugging: 'tutorial/debugging.md'
- Extending OpenAPI: 'tutorial/extending-openapi.md'
- Concurrency and async / await: 'async.md'

View File

@@ -20,7 +20,7 @@ classifiers = [
]
requires = [
"starlette >=0.11.1,<=0.12.0",
"pydantic >=0.17,<=0.26.0"
"pydantic >=0.26,<=0.26.0"
]
description-file = "README.md"
requires-python = ">=3.6"

View File

@@ -0,0 +1,68 @@
from fastapi import Depends, FastAPI
from starlette.testclient import TestClient
app = FastAPI()
counter_holder = {"counter": 0}
async def dep_counter():
counter_holder["counter"] += 1
return counter_holder["counter"]
async def super_dep(count: int = Depends(dep_counter)):
return count
@app.get("/counter/")
async def get_counter(count: int = Depends(dep_counter)):
return {"counter": count}
@app.get("/sub-counter/")
async def get_sub_counter(
subcount: int = Depends(super_dep), count: int = Depends(dep_counter)
):
return {"counter": count, "subcounter": subcount}
@app.get("/sub-counter-no-cache/")
async def get_sub_counter_no_cache(
subcount: int = Depends(super_dep),
count: int = Depends(dep_counter, use_cache=False),
):
return {"counter": count, "subcounter": subcount}
client = TestClient(app)
def test_normal_counter():
counter_holder["counter"] = 0
response = client.get("/counter/")
assert response.status_code == 200
assert response.json() == {"counter": 1}
response = client.get("/counter/")
assert response.status_code == 200
assert response.json() == {"counter": 2}
def test_sub_counter():
counter_holder["counter"] = 0
response = client.get("/sub-counter/")
assert response.status_code == 200
assert response.json() == {"counter": 1, "subcounter": 1}
response = client.get("/sub-counter/")
assert response.status_code == 200
assert response.json() == {"counter": 2, "subcounter": 2}
def test_sub_counter_no_cache():
counter_holder["counter"] = 0
response = client.get("/sub-counter-no-cache/")
assert response.status_code == 200
assert response.json() == {"counter": 2, "subcounter": 1}
response = client.get("/sub-counter-no-cache/")
assert response.status_code == 200
assert response.json() == {"counter": 4, "subcounter": 3}

View File

@@ -0,0 +1,313 @@
import pytest
from fastapi import APIRouter, Depends, FastAPI
from starlette.testclient import TestClient
app = FastAPI()
router = APIRouter()
async def common_parameters(q: str, skip: int = 0, limit: int = 100):
return {"q": q, "skip": skip, "limit": limit}
@app.get("/main-depends/")
async def main_depends(commons: dict = Depends(common_parameters)):
return {"in": "main-depends", "params": commons}
@app.get("/decorator-depends/", dependencies=[Depends(common_parameters)])
async def decorator_depends():
return {"in": "decorator-depends"}
@router.get("/router-depends/")
async def router_depends(commons: dict = Depends(common_parameters)):
return {"in": "router-depends", "params": commons}
@router.get("/router-decorator-depends/", dependencies=[Depends(common_parameters)])
async def router_decorator_depends():
return {"in": "router-decorator-depends"}
app.include_router(router)
client = TestClient(app)
async def overrider_dependency_simple(q: str = None):
return {"q": q, "skip": 5, "limit": 10}
async def overrider_sub_dependency(k: str):
return {"k": k}
async def overrider_dependency_with_sub(msg: dict = Depends(overrider_sub_dependency)):
return msg
@pytest.mark.parametrize(
"url,status_code,expected",
[
(
"/main-depends/",
422,
{
"detail": [
{
"loc": ["query", "q"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
(
"/main-depends/?q=foo",
200,
{"in": "main-depends", "params": {"q": "foo", "skip": 0, "limit": 100}},
),
(
"/main-depends/?q=foo&skip=100&limit=200",
200,
{"in": "main-depends", "params": {"q": "foo", "skip": 100, "limit": 200}},
),
(
"/decorator-depends/",
422,
{
"detail": [
{
"loc": ["query", "q"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
("/decorator-depends/?q=foo", 200, {"in": "decorator-depends"}),
(
"/decorator-depends/?q=foo&skip=100&limit=200",
200,
{"in": "decorator-depends"},
),
(
"/router-depends/",
422,
{
"detail": [
{
"loc": ["query", "q"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
(
"/router-depends/?q=foo",
200,
{"in": "router-depends", "params": {"q": "foo", "skip": 0, "limit": 100}},
),
(
"/router-depends/?q=foo&skip=100&limit=200",
200,
{"in": "router-depends", "params": {"q": "foo", "skip": 100, "limit": 200}},
),
(
"/router-decorator-depends/",
422,
{
"detail": [
{
"loc": ["query", "q"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
("/router-decorator-depends/?q=foo", 200, {"in": "router-decorator-depends"}),
(
"/router-decorator-depends/?q=foo&skip=100&limit=200",
200,
{"in": "router-decorator-depends"},
),
],
)
def test_normal_app(url, status_code, expected):
response = client.get(url)
assert response.status_code == status_code
assert response.json() == expected
@pytest.mark.parametrize(
"url,status_code,expected",
[
(
"/main-depends/",
200,
{"in": "main-depends", "params": {"q": None, "skip": 5, "limit": 10}},
),
(
"/main-depends/?q=foo",
200,
{"in": "main-depends", "params": {"q": "foo", "skip": 5, "limit": 10}},
),
(
"/main-depends/?q=foo&skip=100&limit=200",
200,
{"in": "main-depends", "params": {"q": "foo", "skip": 5, "limit": 10}},
),
("/decorator-depends/", 200, {"in": "decorator-depends"}),
(
"/router-depends/",
200,
{"in": "router-depends", "params": {"q": None, "skip": 5, "limit": 10}},
),
(
"/router-depends/?q=foo",
200,
{"in": "router-depends", "params": {"q": "foo", "skip": 5, "limit": 10}},
),
(
"/router-depends/?q=foo&skip=100&limit=200",
200,
{"in": "router-depends", "params": {"q": "foo", "skip": 5, "limit": 10}},
),
("/router-decorator-depends/", 200, {"in": "router-decorator-depends"}),
],
)
def test_override_simple(url, status_code, expected):
app.dependency_overrides[common_parameters] = overrider_dependency_simple
response = client.get(url)
assert response.status_code == status_code
assert response.json() == expected
app.dependency_overrides = {}
@pytest.mark.parametrize(
"url,status_code,expected",
[
(
"/main-depends/",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
(
"/main-depends/?q=foo",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
("/main-depends/?k=bar", 200, {"in": "main-depends", "params": {"k": "bar"}}),
(
"/decorator-depends/",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
(
"/decorator-depends/?q=foo",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
("/decorator-depends/?k=bar", 200, {"in": "decorator-depends"}),
(
"/router-depends/",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
(
"/router-depends/?q=foo",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
(
"/router-depends/?k=bar",
200,
{"in": "router-depends", "params": {"k": "bar"}},
),
(
"/router-decorator-depends/",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
(
"/router-decorator-depends/?q=foo",
422,
{
"detail": [
{
"loc": ["query", "k"],
"msg": "field required",
"type": "value_error.missing",
}
]
},
),
("/router-decorator-depends/?k=bar", 200, {"in": "router-decorator-depends"}),
],
)
def test_override_with_sub(url, status_code, expected):
app.dependency_overrides[common_parameters] = overrider_dependency_with_sub
response = client.get(url)
assert response.status_code == status_code
assert response.json() == expected
app.dependency_overrides = {}

View File

@@ -0,0 +1,77 @@
from typing import Dict, List, Tuple
import pytest
from fastapi import FastAPI
from pydantic import BaseModel
def test_invalid_sequence():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/{id}")
def read_items(id: List[Item]):
pass # pragma: no cover
def test_invalid_tuple():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/{id}")
def read_items(id: Tuple[Item, Item]):
pass # pragma: no cover
def test_invalid_dict():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/{id}")
def read_items(id: Dict[str, Item]):
pass # pragma: no cover
def test_invalid_simple_list():
with pytest.raises(AssertionError):
app = FastAPI()
@app.get("/items/{id}")
def read_items(id: list):
pass # pragma: no cover
def test_invalid_simple_tuple():
with pytest.raises(AssertionError):
app = FastAPI()
@app.get("/items/{id}")
def read_items(id: tuple):
pass # pragma: no cover
def test_invalid_simple_set():
with pytest.raises(AssertionError):
app = FastAPI()
@app.get("/items/{id}")
def read_items(id: set):
pass # pragma: no cover
def test_invalid_simple_dict():
with pytest.raises(AssertionError):
app = FastAPI()
@app.get("/items/{id}")
def read_items(id: dict):
pass # pragma: no cover

View File

@@ -0,0 +1,53 @@
from typing import Dict, List, Tuple
import pytest
from fastapi import FastAPI, Query
from pydantic import BaseModel
def test_invalid_sequence():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/")
def read_items(q: List[Item] = Query(None)):
pass # pragma: no cover
def test_invalid_tuple():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/")
def read_items(q: Tuple[Item, Item] = Query(None)):
pass # pragma: no cover
def test_invalid_dict():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/")
def read_items(q: Dict[str, Item] = Query(None)):
pass # pragma: no cover
def test_invalid_simple_dict():
with pytest.raises(AssertionError):
app = FastAPI()
class Item(BaseModel):
title: str
@app.get("/items/")
def read_items(q: dict = Query(None)):
pass # pragma: no cover

View File

@@ -64,5 +64,5 @@ def test_security_http_bearer_no_credentials():
def test_security_http_bearer_incorrect_scheme_credentials():
response = client.get("/users/me", headers={"Authorization": "Basic notreally"})
assert response.status_code == 403
assert response.json() == {"detail": "Invalid authentication credentials"}
assert response.status_code == 200
assert response.json() == {"msg": "Create an account first"}

View File

@@ -0,0 +1,120 @@
import pytest
from starlette.testclient import TestClient
from path_params.tutorial005 import app
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/model/{model_name}": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Get Model",
"operationId": "get_model_model__model_name__get",
"parameters": [
{
"required": True,
"schema": {
"title": "Model_Name",
"enum": ["alexnet", "resnet", "lenet"],
},
"name": "model_name",
"in": "path",
}
],
}
}
},
"components": {
"schemas": {
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
}
},
}
def test_openapi():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
@pytest.mark.parametrize(
"url,status_code,expected",
[
(
"/model/alexnet",
200,
{"model_name": "alexnet", "message": "Deep Learning FTW!"},
),
(
"/model/lenet",
200,
{"model_name": "lenet", "message": "LeCNN all the images"},
),
(
"/model/resnet",
200,
{"model_name": "resnet", "message": "Have some residuals"},
),
(
"/model/foo",
422,
{
"detail": [
{
"loc": ["path", "model_name"],
"msg": "value is not a valid enumeration member",
"type": "type_error.enum",
}
]
},
),
],
)
def test_get_enums(url, status_code, expected):
response = client.get(url)
assert response.status_code == status_code
assert response.json() == expected

View File

@@ -0,0 +1,95 @@
from starlette.testclient import TestClient
from query_params.tutorial007 import app
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/items/{item_id}": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Read User Item",
"operationId": "read_user_item_items__item_id__get",
"parameters": [
{
"required": True,
"schema": {"title": "Item_Id", "type": "string"},
"name": "item_id",
"in": "path",
},
{
"required": False,
"schema": {"title": "Limit", "type": "integer"},
"name": "limit",
"in": "query",
},
],
}
}
},
"components": {
"schemas": {
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
}
},
}
def test_openapi():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_read_item():
response = client.get("/items/foo")
assert response.status_code == 200
assert response.json() == {"item_id": "foo", "limit": None}
def test_read_item_query():
response = client.get("/items/foo?limit=5")
assert response.status_code == 200
assert response.json() == {"item_id": "foo", "limit": 5}

View File

@@ -86,3 +86,10 @@ def test_multi_query_values():
response = client.get(url)
assert response.status_code == 200
assert response.json() == {"q": ["foo", "bar"]}
def test_query_no_values():
url = "/items/"
response = client.get(url)
assert response.status_code == 200
assert response.json() == {"q": None}

View File

@@ -0,0 +1,96 @@
from starlette.testclient import TestClient
from query_params_str_validations.tutorial012 import app
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/items/": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Read Items",
"operationId": "read_items_items__get",
"parameters": [
{
"required": False,
"schema": {
"title": "Q",
"type": "array",
"items": {"type": "string"},
"default": ["foo", "bar"],
},
"name": "q",
"in": "query",
}
],
}
}
},
"components": {
"schemas": {
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_default_query_values():
url = "/items/"
response = client.get(url)
assert response.status_code == 200
assert response.json() == {"q": ["foo", "bar"]}
def test_multi_query_values():
url = "/items/?q=baz&q=foobar"
response = client.get(url)
assert response.status_code == 200
assert response.json() == {"q": ["baz", "foobar"]}

View File

@@ -0,0 +1,91 @@
from starlette.testclient import TestClient
from query_params_str_validations.tutorial013 import app
client = TestClient(app)
openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "Fast API", "version": "0.1.0"},
"paths": {
"/items/": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
"content": {"application/json": {"schema": {}}},
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
},
},
},
"summary": "Read Items",
"operationId": "read_items_items__get",
"parameters": [
{
"required": False,
"schema": {"title": "Q", "type": "array"},
"name": "q",
"in": "query",
}
],
}
}
},
"components": {
"schemas": {
"ValidationError": {
"title": "ValidationError",
"required": ["loc", "msg", "type"],
"type": "object",
"properties": {
"loc": {
"title": "Location",
"type": "array",
"items": {"type": "string"},
},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
},
"HTTPValidationError": {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": "#/components/schemas/ValidationError"},
}
},
},
}
},
}
def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_schema
def test_multi_query_values():
url = "/items/?q=foo&q=bar"
response = client.get(url)
assert response.status_code == 200
assert response.json() == {"q": ["foo", "bar"]}
def test_query_no_values():
url = "/items/"
response = client.get(url)
assert response.status_code == 200
assert response.json() == {"q": None}

View File

View File

@@ -0,0 +1,56 @@
from dependency_testing.tutorial001 import (
app,
client,
test_override_in_items,
test_override_in_items_with_params,
test_override_in_items_with_q,
)
def test_override_in_items_run():
test_override_in_items()
def test_override_in_items_with_q_run():
test_override_in_items_with_q()
def test_override_in_items_with_params_run():
test_override_in_items_with_params()
def test_override_in_users():
response = client.get("/users/")
assert response.status_code == 200
assert response.json() == {
"message": "Hello Users!",
"params": {"q": None, "skip": 5, "limit": 10},
}
def test_override_in_users_with_q():
response = client.get("/users/?q=foo")
assert response.status_code == 200
assert response.json() == {
"message": "Hello Users!",
"params": {"q": "foo", "skip": 5, "limit": 10},
}
def test_override_in_users_with_params():
response = client.get("/users/?q=foo&skip=100&limit=200")
assert response.status_code == 200
assert response.json() == {
"message": "Hello Users!",
"params": {"q": "foo", "skip": 5, "limit": 10},
}
def test_normal_app():
app.dependency_overrides = None
response = client.get("/items/?q=foo&skip=100&limit=200")
assert response.status_code == 200
assert response.json() == {
"message": "Hello Items!",
"params": {"q": "foo", "skip": 100, "limit": 200},
}