Files
fastapi/fastapi/dependencies/utils.py

1055 lines
38 KiB
Python

import dataclasses
import inspect
import sys
from collections.abc import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Callable,
Generator,
Iterable,
Iterator,
Mapping,
Sequence,
)
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import (
Annotated,
Any,
ForwardRef,
Literal,
Union,
cast,
get_args,
get_origin,
)
from fastapi import params
from fastapi._compat import (
ModelField,
RequiredParam,
Undefined,
copy_field_info,
create_body_model,
evaluate_forwardref,
field_annotation_is_scalar,
field_annotation_is_scalar_sequence,
field_annotation_is_sequence,
get_cached_model_fields,
get_missing_field_error,
is_bytes_or_nonable_bytes_annotation,
is_bytes_sequence_annotation,
is_scalar_field,
is_uploadfile_or_nonable_uploadfile_annotation,
is_uploadfile_sequence_annotation,
lenient_issubclass,
sequence_types,
serialize_sequence_value,
value_is_sequence,
)
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel, Json
from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import (
FormData,
Headers,
ImmutableMultiDict,
QueryParams,
UploadFile,
)
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
from typing_inspection.typing_objects import is_typealiastype
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
'You can install "python-multipart" with: \n\n'
"pip install python-multipart\n"
)
multipart_incorrect_install_error = (
'Form data requires "python-multipart" to be installed. '
'It seems you installed "multipart" instead. \n'
'You can remove "multipart" with: \n\n'
"pip uninstall multipart\n\n"
'And then install "python-multipart" with: \n\n'
"pip install python-multipart\n"
)
def ensure_multipart_is_installed() -> None:
try:
from python_multipart import __version__
# Import an attribute that can be mocked/deleted in testing
assert __version__ > "0.0.12"
except (ImportError, AssertionError):
try:
# __version__ is available in both multiparts, and can be mocked
from multipart import __version__ # type: ignore[no-redef,import-untyped]
assert __version__
try:
# parse_options_header is only available in the right multipart
from multipart.multipart import ( # type: ignore[import-untyped]
parse_options_header,
)
assert parse_options_header
except ImportError:
logger.error(multipart_incorrect_install_error)
raise RuntimeError(multipart_incorrect_install_error) from None
except ImportError:
logger.error(multipart_not_installed_error)
raise RuntimeError(multipart_not_installed_error) from None
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency"
)
own_oauth_scopes: list[str] = []
if isinstance(depends, params.Security) and depends.scopes:
own_oauth_scopes.extend(depends.scopes)
return get_dependant(
path=path,
call=depends.dependency,
scope=depends.scope,
own_oauth_scopes=own_oauth_scopes,
)
def get_flat_dependant(
dependant: Dependant,
*,
skip_repeats: bool = False,
visited: list[DependencyCacheKey] | None = None,
parent_oauth_scopes: list[str] | None = None,
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
dependant.oauth_scopes or []
)
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
query_params=dependant.query_params.copy(),
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
name=dependant.name,
call=dependant.call,
request_param_name=dependant.request_param_name,
websocket_param_name=dependant.websocket_param_name,
http_connection_param_name=dependant.http_connection_param_name,
response_param_name=dependant.response_param_name,
background_tasks_param_name=dependant.background_tasks_param_name,
security_scopes_param_name=dependant.security_scopes_param_name,
own_oauth_scopes=dependant.own_oauth_scopes,
parent_oauth_scopes=use_parent_oauth_scopes,
use_cache=dependant.use_cache,
path=dependant.path,
scope=dependant.scope,
)
for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
sub_dependant,
skip_repeats=skip_repeats,
visited=visited,
parent_oauth_scopes=flat_dependant.oauth_scopes,
)
flat_dependant.dependencies.append(flat_sub)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.dependencies.extend(flat_sub.dependencies)
return flat_dependant
def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
if not fields:
return fields
first_field = fields[0]
if len(fields) == 1 and lenient_issubclass(
first_field.field_info.annotation, BaseModel
):
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
return fields_to_extract
return fields
def get_flat_params(dependant: Dependant) -> list[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
header_params = _get_flat_fields_from_params(flat_dependant.header_params)
cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
return path_params + query_params + header_params + cookie_params
def _get_signature(call: Callable[..., Any]) -> inspect.Signature:
try:
signature = inspect.signature(call, eval_str=True)
except NameError:
# Handle type annotations with if TYPE_CHECKING, not used by FastAPI
# e.g. dependency return types
if sys.version_info >= (3, 14):
from annotationlib import Format
signature = inspect.signature(call, annotation_format=Format.FORWARDREF)
else:
signature = inspect.signature(call)
return signature
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = _get_signature(call)
unwrapped = inspect.unwrap(call)
globalns = getattr(unwrapped, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param.annotation, globalns),
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
if annotation is type(None):
return None
return annotation
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = _get_signature(call)
unwrapped = inspect.unwrap(call)
annotation = signature.return_annotation
if annotation is inspect.Signature.empty:
return None
globalns = getattr(unwrapped, "__globals__", {})
return get_typed_annotation(annotation, globalns)
_STREAM_ORIGINS = {
AsyncIterable,
AsyncIterator,
AsyncGenerator,
Iterable,
Iterator,
Generator,
}
def get_stream_item_type(annotation: Any) -> Any | None:
origin = get_origin(annotation)
if origin is not None and origin in _STREAM_ORIGINS:
type_args = get_args(annotation)
if type_args:
return type_args[0]
return Any
return None
def get_dependant(
*,
path: str,
call: Callable[..., Any],
name: str | None = None,
own_oauth_scopes: list[str] | None = None,
parent_oauth_scopes: list[str] | None = None,
use_cache: bool = True,
scope: Literal["function", "request"] | None = None,
) -> Dependant:
dependant = Dependant(
call=call,
name=name,
path=path,
use_cache=use_cache,
scope=scope,
own_oauth_scopes=own_oauth_scopes,
parent_oauth_scopes=parent_oauth_scopes,
)
current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
param_details = analyze_param(
param_name=param_name,
annotation=param.annotation,
value=param.default,
is_path_param=is_path_param,
)
if param_details.depends is not None:
assert param_details.depends.dependency
if (
(dependant.is_gen_callable or dependant.is_async_gen_callable)
and dependant.computed_scope == "request"
and param_details.depends.scope == "function"
):
assert dependant.call
raise DependencyScopeError(
f'The dependency "{dependant.call.__name__}" has a scope of '
'"request", it cannot depend on dependencies with scope "function".'
)
sub_own_oauth_scopes: list[str] = []
if isinstance(param_details.depends, params.Security):
if param_details.depends.scopes:
sub_own_oauth_scopes = list(param_details.depends.scopes)
sub_dependant = get_dependant(
path=path,
call=param_details.depends.dependency,
name=param_name,
own_oauth_scopes=sub_own_oauth_scopes,
parent_oauth_scopes=current_scopes,
use_cache=param_details.depends.use_cache,
scope=param_details.depends.scope,
)
dependant.dependencies.append(sub_dependant)
continue
if add_non_field_param_to_dependency(
param_name=param_name,
type_annotation=param_details.type_annotation,
dependant=dependant,
):
assert param_details.field is None, (
f"Cannot specify multiple FastAPI annotations for {param_name!r}"
)
continue
assert param_details.field is not None
if isinstance(param_details.field.field_info, params.Body):
dependant.body_params.append(param_details.field)
else:
add_param_to_fields(field=param_details.field, dependant=dependant)
return dependant
def add_non_field_param_to_dependency(
*, param_name: str, type_annotation: Any, dependant: Dependant
) -> bool | None:
if lenient_issubclass(type_annotation, Request):
dependant.request_param_name = param_name
return True
elif lenient_issubclass(type_annotation, WebSocket):
dependant.websocket_param_name = param_name
return True
elif lenient_issubclass(type_annotation, HTTPConnection):
dependant.http_connection_param_name = param_name
return True
elif lenient_issubclass(type_annotation, Response):
dependant.response_param_name = param_name
return True
elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
dependant.background_tasks_param_name = param_name
return True
elif lenient_issubclass(type_annotation, SecurityScopes):
dependant.security_scopes_param_name = param_name
return True
return None
@dataclass
class ParamDetails:
type_annotation: Any
depends: params.Depends | None
field: ModelField | None
def analyze_param(
*,
param_name: str,
annotation: Any,
value: Any,
is_path_param: bool,
) -> ParamDetails:
field_info = None
depends = None
type_annotation: Any = Any
use_annotation: Any = Any
if is_typealiastype(annotation):
# unpack in case PEP 695 type syntax is used
annotation = annotation.__value__
if annotation is not inspect.Signature.empty:
use_annotation = annotation
type_annotation = annotation
# Extract Annotated info
if get_origin(use_annotation) is Annotated:
annotated_args = get_args(annotation)
type_annotation = annotated_args[0]
fastapi_annotations = [
arg
for arg in annotated_args[1:]
if isinstance(arg, (FieldInfo, params.Depends))
]
fastapi_specific_annotations = [
arg
for arg in fastapi_annotations
if isinstance(
arg,
(
params.Param,
params.Body,
params.Depends,
),
)
]
if fastapi_specific_annotations:
fastapi_annotation: FieldInfo | params.Depends | None = (
fastapi_specific_annotations[-1]
)
else:
fastapi_annotation = None
# Set default for Annotated FieldInfo
if isinstance(fastapi_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` below.
field_info = copy_field_info(
field_info=fastapi_annotation,
annotation=use_annotation,
)
assert (
field_info.default == Undefined or field_info.default == RequiredParam
), (
f"`{field_info.__class__.__name__}` default value cannot be set in"
f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
)
if value is not inspect.Signature.empty:
assert not is_path_param, "Path parameters cannot have default values"
field_info.default = value
else:
field_info.default = RequiredParam
# Get Annotated Depends
elif isinstance(fastapi_annotation, params.Depends):
depends = fastapi_annotation
# Get Depends from default value
if isinstance(value, params.Depends):
assert depends is None, (
"Cannot specify `Depends` in `Annotated` and default value"
f" together for {param_name!r}"
)
assert field_info is None, (
"Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
f" default value together for {param_name!r}"
)
depends = value
# Get FieldInfo from default value
elif isinstance(value, FieldInfo):
assert field_info is None, (
"Cannot specify FastAPI annotations in `Annotated` and default value"
f" together for {param_name!r}"
)
field_info = value
if isinstance(field_info, FieldInfo):
field_info.annotation = type_annotation
# Get Depends from type annotation
if depends is not None and depends.dependency is None:
# Copy `depends` before mutating it
depends = copy(depends)
depends = dataclasses.replace(depends, dependency=type_annotation)
# Handle non-param type annotations like Request
# Only apply special handling when there's no explicit Depends - if there's a Depends,
# the dependency will be called and its return value used instead of the special injection
if depends is None and lenient_issubclass(
type_annotation,
(
Request,
WebSocket,
HTTPConnection,
Response,
StarletteBackgroundTasks,
SecurityScopes,
),
):
assert field_info is None, (
f"Cannot specify FastAPI annotation for type {type_annotation!r}"
)
# Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
elif field_info is None and depends is None:
default_value = value if value is not inspect.Signature.empty else RequiredParam
if is_path_param:
# We might check here that `default_value is RequiredParam`, but the fact is that the same
# parameter might sometimes be a path parameter and sometimes not. See
# `tests/test_infer_param_optionality.py` for an example.
field_info = params.Path(annotation=use_annotation)
elif is_uploadfile_or_nonable_uploadfile_annotation(
type_annotation
) or is_uploadfile_sequence_annotation(type_annotation):
field_info = params.File(annotation=use_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation):
field_info = params.Body(annotation=use_annotation, default=default_value)
else:
field_info = params.Query(annotation=use_annotation, default=default_value)
field = None
# It's a field_info, not a dependency
if field_info is not None:
# Handle field_info.in_
if is_path_param:
assert isinstance(field_info, params.Path), (
f"Cannot use `{field_info.__class__.__name__}` for path param"
f" {param_name!r}"
)
elif (
isinstance(field_info, params.Param)
and getattr(field_info, "in_", None) is None
):
field_info.in_ = params.ParamTypes.query
use_annotation_from_field_info = use_annotation
if isinstance(field_info, params.Form):
ensure_multipart_is_installed()
if not field_info.alias and getattr(field_info, "convert_underscores", None):
alias = param_name.replace("_", "-")
else:
alias = field_info.alias or param_name
field_info.alias = alias
field = create_model_field(
name=param_name,
type_=use_annotation_from_field_info,
default=field_info.default,
alias=alias,
field_info=field_info,
)
if is_path_param:
assert is_scalar_field(field=field), (
"Path params must be of one of the supported types"
)
elif isinstance(field_info, params.Query):
assert (
is_scalar_field(field)
or field_annotation_is_scalar_sequence(field.field_info.annotation)
or lenient_issubclass(field.field_info.annotation, BaseModel)
), f"Query parameter {param_name!r} must be one of the supported types"
return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
field_info = field.field_info
field_info_in = getattr(field_info, "in_", None)
if field_info_in == params.ParamTypes.path:
dependant.path_params.append(field)
elif field_info_in == params.ParamTypes.query:
dependant.query_params.append(field)
elif field_info_in == params.ParamTypes.header:
dependant.header_params.append(field)
else:
assert field_info_in == params.ParamTypes.cookie, (
f"non-body parameters must be in path, query, header or cookie: {field.name}"
)
dependant.cookie_params.append(field)
async def _solve_generator(
*, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]
) -> Any:
assert dependant.call
if dependant.is_async_gen_callable:
cm = asynccontextmanager(dependant.call)(**sub_values)
elif dependant.is_gen_callable:
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
return await stack.enter_async_context(cm)
@dataclass
class SolvedDependency:
values: dict[str, Any]
errors: list[Any]
background_tasks: StarletteBackgroundTasks | None
response: Response
dependency_cache: dict[DependencyCacheKey, Any]
async def solve_dependencies(
*,
request: Request | WebSocket,
dependant: Dependant,
body: dict[str, Any] | FormData | None = None,
background_tasks: StarletteBackgroundTasks | None = None,
response: Response | None = None,
dependency_overrides_provider: Any | None = None,
dependency_cache: dict[DependencyCacheKey, Any] | None = None,
# TODO: remove this parameter later, no longer used, not removing it yet as some
# people might be monkey patching this function (although that's not supported)
async_exit_stack: AsyncExitStack,
embed_body_fields: bool,
) -> SolvedDependency:
request_astack = request.scope.get("fastapi_inner_astack")
assert isinstance(request_astack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
)
function_astack = request.scope.get("fastapi_function_astack")
assert isinstance(function_astack, AsyncExitStack), (
"fastapi_function_astack not found in request scope"
)
values: dict[str, Any] = {}
errors: list[Any] = []
if response is None:
response = Response()
del response.headers["content-length"]
response.status_code = None # type: ignore
if dependency_cache is None:
dependency_cache = {}
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
path=use_path,
call=call,
name=sub_dependant.name,
parent_oauth_scopes=sub_dependant.oauth_scopes,
scope=sub_dependant.scope,
)
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
embed_body_fields=embed_body_fields,
)
background_tasks = solved_result.background_tasks
if solved_result.errors:
errors.extend(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif (
use_sub_dependant.is_gen_callable or use_sub_dependant.is_async_gen_callable
):
use_astack = request_astack
if sub_dependant.scope == "function":
use_astack = function_astack
solved = await _solve_generator(
dependant=use_sub_dependant,
stack=use_astack,
sub_values=solved_result.values,
)
elif use_sub_dependant.is_coroutine_callable:
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values)
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
(
body_values,
body_errors,
) = await request_body_to_args( # body_params checked above
body_fields=dependant.body_params,
received_body=body,
embed_body_fields=embed_body_fields,
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name:
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.oauth_scopes
)
return SolvedDependency(
values=values,
errors=errors,
background_tasks=background_tasks,
response=response,
dependency_cache=dependency_cache,
)
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
if value is None:
if field.field_info.is_required():
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
return field.validate(value, values, loc=loc)
def _is_json_field(field: ModelField) -> bool:
return any(type(item) is Json for item in field.field_info.metadata)
def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: str | None = None
) -> Any:
alias = alias or get_validation_alias(field)
if (
(not _is_json_field(field))
and field_annotation_is_sequence(field.field_info.annotation)
and isinstance(values, (ImmutableMultiDict, Headers))
):
value = values.getlist(alias)
else:
value = values.get(alias, None)
if (
value is None
or (
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
and value == ""
)
or (
field_annotation_is_sequence(field.field_info.annotation)
and len(value) == 0
)
):
if field.field_info.is_required():
return
else:
return deepcopy(field.default)
return value
def request_params_to_args(
fields: Sequence[ModelField],
received_params: Mapping[str, Any] | QueryParams | Headers,
) -> tuple[dict[str, Any], list[Any]]:
values: dict[str, Any] = {}
errors: list[dict[str, Any]] = []
if not fields:
return values, errors
first_field = fields[0]
fields_to_extract = fields
single_not_embedded_field = False
default_convert_underscores = True
if len(fields) == 1 and lenient_issubclass(
first_field.field_info.annotation, BaseModel
):
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
single_not_embedded_field = True
# If headers are in a Pydantic model, the way to disable convert_underscores
# would be with Header(convert_underscores=False) at the Pydantic model level
default_convert_underscores = getattr(
first_field.field_info, "convert_underscores", True
)
params_to_process: dict[str, Any] = {}
processed_keys = set()
for field in fields_to_extract:
alias = None
if isinstance(received_params, Headers):
# Handle fields extracted from a Pydantic Model for a header, each field
# doesn't have a FieldInfo of type Header with the default convert_underscores=True
convert_underscores = getattr(
field.field_info, "convert_underscores", default_convert_underscores
)
if convert_underscores:
alias = get_validation_alias(field)
if alias == field.name:
alias = alias.replace("_", "-")
value = _get_multidict_value(field, received_params, alias=alias)
if value is not None:
params_to_process[get_validation_alias(field)] = value
processed_keys.add(alias or get_validation_alias(field))
for key in received_params.keys():
if key not in processed_keys:
if hasattr(received_params, "getlist"):
value = received_params.getlist(key)
if isinstance(value, list) and (len(value) == 1):
params_to_process[key] = value[0]
else:
params_to_process[key] = value
else:
params_to_process[key] = received_params.get(key)
if single_not_embedded_field:
field_info = first_field.field_info
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
loc: tuple[str, ...] = (field_info.in_.value,)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_
for field in fields:
value = _get_multidict_value(field, received_params)
field_info = field.field_info
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
loc = (field_info.in_.value, get_validation_alias(field))
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc
)
if errors_:
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def is_union_of_base_models(field_type: Any) -> bool:
"""Check if field type is a Union where all members are BaseModel subclasses."""
from fastapi.types import UnionType
origin = get_origin(field_type)
# Check if it's a Union type (covers both typing.Union and types.UnionType in Python 3.10+)
if origin is not Union and origin is not UnionType:
return False
union_args = get_args(field_type)
for arg in union_args:
if not lenient_issubclass(arg, BaseModel):
return False
return True
def _should_embed_body_fields(fields: list[ModelField]) -> bool:
if not fields:
return False
# More than one dependency could have the same field, it would show up as multiple
# fields but it's the same one, so count them by name
body_param_names_set = {field.name for field in fields}
# A top level field has to be a single field, not multiple
if len(body_param_names_set) > 1:
return True
first_field = fields[0]
# If it explicitly specifies it is embedded, it has to be embedded
if getattr(first_field.field_info, "embed", None):
return True
# If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
# otherwise it has to be embedded, so that the key value pair can be extracted
if (
isinstance(first_field.field_info, params.Form)
and not lenient_issubclass(first_field.field_info.annotation, BaseModel)
and not is_union_of_base_models(first_field.field_info.annotation)
):
return True
return False
async def _extract_form_body(
body_fields: list[ModelField],
received_body: FormData,
) -> dict[str, Any]:
values = {}
for field in body_fields:
value = _get_multidict_value(field, received_body)
field_info = field.field_info
if (
isinstance(field_info, params.File)
and is_bytes_or_nonable_bytes_annotation(field.field_info.annotation)
and isinstance(value, UploadFile)
):
value = await value.read()
elif (
is_bytes_sequence_annotation(field.field_info.annotation)
and isinstance(field_info, params.File)
and value_is_sequence(value)
):
# For types
assert isinstance(value, sequence_types)
results: list[bytes | str] = []
for sub_value in value:
results.append(await sub_value.read())
value = serialize_sequence_value(field=field, value=results)
if value is not None:
values[get_validation_alias(field)] = value
field_aliases = {get_validation_alias(field) for field in body_fields}
for key in received_body.keys():
if key not in field_aliases:
param_values = received_body.getlist(key)
if len(param_values) == 1:
values[key] = param_values[0]
else:
values[key] = param_values
return values
async def request_body_to_args(
body_fields: list[ModelField],
received_body: dict[str, Any] | FormData | None,
embed_body_fields: bool,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
values: dict[str, Any] = {}
errors: list[dict[str, Any]] = []
assert body_fields, "request_body_to_args() should be called with fields"
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
first_field = body_fields[0]
body_to_process = received_body
fields_to_extract: list[ModelField] = body_fields
if (
single_not_embedded_field
and lenient_issubclass(first_field.field_info.annotation, BaseModel)
and isinstance(received_body, FormData)
):
fields_to_extract = get_cached_model_fields(first_field.field_info.annotation)
if isinstance(received_body, FormData):
body_to_process = await _extract_form_body(fields_to_extract, received_body)
if single_not_embedded_field:
loc: tuple[str, ...] = ("body",)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", get_validation_alias(field))
value: Any | None = None
if body_to_process is not None:
try:
value = body_to_process.get(get_validation_alias(field))
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))
continue
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc
)
if errors_:
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
def get_body_field(
*, flat_dependant: Dependant, name: str, embed_body_fields: bool
) -> ModelField | None:
"""
Get a ModelField representing the request body for a path operation, combining
all body parameters into a single field if necessary.
Used to check if it's form data (with `isinstance(body_field, params.Form)`)
or JSON and to generate the JSON Schema for a request body.
This is **not** used to validate/parse the request body, that's done with each
individual body parameter.
"""
if not flat_dependant.body_params:
return None
first_param = flat_dependant.body_params[0]
if not embed_body_fields:
return first_param
model_name = "Body_" + name
BodyModel = create_body_model(
fields=flat_dependant.body_params, model_name=model_name
)
required = any(
True for f in flat_dependant.body_params if f.field_info.is_required()
)
BodyFieldInfo_kwargs: dict[str, Any] = {
"annotation": BodyModel,
"alias": "body",
}
if not required:
BodyFieldInfo_kwargs["default"] = None
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
BodyFieldInfo: type[params.Body] = params.File
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
BodyFieldInfo = params.Form
else:
BodyFieldInfo = params.Body
body_param_media_types = [
f.field_info.media_type
for f in flat_dependant.body_params
if isinstance(f.field_info, params.Body)
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
final_field = create_model_field(
name="body",
type_=BodyModel,
alias="body",
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
return final_field
def get_validation_alias(field: ModelField) -> str:
va = getattr(field, "validation_alias", None)
return va or field.alias