mirror of
https://github.com/fastapi/fastapi.git
synced 2026-01-07 21:49:52 -05:00
♻️ Use new Pydantic v2 JSON Schema generator (#9813)
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
a65281fe09
commit
d4e3dcfa3a
@@ -79,6 +79,7 @@ if PYDANTIC_V2:
|
||||
class ModelField:
|
||||
field_info: FieldInfo
|
||||
name: str
|
||||
mode: Literal["validation", "serialization"] = "validation"
|
||||
|
||||
@property
|
||||
def alias(self) -> str:
|
||||
@@ -178,9 +179,12 @@ if PYDANTIC_V2:
|
||||
field: ModelField,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
) -> Dict[str, Any]:
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
json_schema = schema_generator.generate_inner(field._type_adapter.core_schema)
|
||||
json_schema = field_mapping[(field, field.mode)]
|
||||
if "$ref" not in json_schema:
|
||||
# TODO remove when deprecating Pydantic v1
|
||||
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
||||
@@ -197,12 +201,12 @@ if PYDANTIC_V2:
|
||||
fields: List[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
|
||||
inputs = [
|
||||
(field, "validation", field._type_adapter.core_schema) for field in fields
|
||||
(field, field.mode, field._type_adapter.core_schema) for field in fields
|
||||
]
|
||||
_, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
|
||||
return definitions # type: ignore[return-value]
|
||||
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) # type: ignore[arg-type]
|
||||
return field_mapping, definitions # type: ignore[return-value]
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
@@ -419,6 +423,9 @@ else:
|
||||
field: ModelField,
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
) -> Dict[str, Any]:
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
return field_schema( # type: ignore[no-any-return]
|
||||
@@ -434,9 +441,11 @@ else:
|
||||
fields: List[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Dict[str, Any]]]:
|
||||
models = get_flat_models_from_fields(fields, known_models=set())
|
||||
return get_model_definitions(flat_models=models, model_name_map=model_name_map)
|
||||
return {}, get_model_definitions(
|
||||
flat_models=models, model_name_map=model_name_map
|
||||
)
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
return is_pv1_scalar_field(field)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union,
|
||||
from fastapi import routing
|
||||
from fastapi._compat import (
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaValue,
|
||||
ModelField,
|
||||
Undefined,
|
||||
get_compat_model_name_map,
|
||||
@@ -30,6 +31,7 @@ from fastapi.utils import (
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import BaseRoute
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||
from typing_extensions import Literal
|
||||
|
||||
validation_error_definition = {
|
||||
"title": "ValidationError",
|
||||
@@ -90,6 +92,9 @@ def get_openapi_operation_parameters(
|
||||
all_route_params: Sequence[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
) -> List[Dict[str, Any]]:
|
||||
parameters = []
|
||||
for param in all_route_params:
|
||||
@@ -101,6 +106,7 @@ def get_openapi_operation_parameters(
|
||||
field=param,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
parameter = {
|
||||
"name": param.alias,
|
||||
@@ -123,6 +129,9 @@ def get_openapi_operation_request_body(
|
||||
body_field: Optional[ModelField],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if not body_field:
|
||||
return None
|
||||
@@ -131,6 +140,7 @@ def get_openapi_operation_request_body(
|
||||
field=body_field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
field_info = cast(Body, body_field.field_info)
|
||||
request_media_type = field_info.media_type
|
||||
@@ -198,6 +208,9 @@ def get_openapi_path(
|
||||
operation_ids: Set[str],
|
||||
schema_generator: GenerateJsonSchema,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
path = {}
|
||||
security_schemes: Dict[str, Any] = {}
|
||||
@@ -228,6 +241,7 @@ def get_openapi_path(
|
||||
all_route_params=all_route_params,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
parameters.extend(operation_parameters)
|
||||
if parameters:
|
||||
@@ -248,6 +262,7 @@ def get_openapi_path(
|
||||
body_field=route.body_field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
if request_body_oai:
|
||||
operation["requestBody"] = request_body_oai
|
||||
@@ -264,6 +279,7 @@ def get_openapi_path(
|
||||
operation_ids=operation_ids,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
callbacks[callback.name] = {callback.path: cb_path}
|
||||
operation["callbacks"] = callbacks
|
||||
@@ -293,6 +309,7 @@ def get_openapi_path(
|
||||
field=route.response_field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
else:
|
||||
response_schema = {}
|
||||
@@ -325,6 +342,7 @@ def get_openapi_path(
|
||||
field=field,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
media_type = route_response_media_type or "application/json"
|
||||
additional_schema = (
|
||||
@@ -437,7 +455,7 @@ def get_openapi(
|
||||
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
|
||||
model_name_map = get_compat_model_name_map(all_fields)
|
||||
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
|
||||
definitions = get_definitions(
|
||||
field_mapping, definitions = get_definitions(
|
||||
fields=all_fields,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
@@ -449,6 +467,7 @@ def get_openapi(
|
||||
operation_ids=operation_ids,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
if result:
|
||||
path, security_schemes, path_definitions = result
|
||||
@@ -467,6 +486,7 @@ def get_openapi(
|
||||
operation_ids=operation_ids,
|
||||
schema_generator=schema_generator,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
)
|
||||
if result:
|
||||
path, security_schemes, path_definitions = result
|
||||
|
||||
@@ -446,7 +446,11 @@ class APIRoute(routing.Route):
|
||||
), f"Status code {status_code} must not have a response body"
|
||||
response_name = "Response_" + self.unique_id
|
||||
self.response_field = create_response_field(
|
||||
name=response_name, type_=self.response_model
|
||||
name=response_name,
|
||||
type_=self.response_model,
|
||||
# TODO: This should actually set mode='serialization', just, that changes the schemas
|
||||
# mode="serialization",
|
||||
mode="validation",
|
||||
)
|
||||
# Create a clone of the field, so that a Pydantic submodel is not returned
|
||||
# as is just because it's an instance of a subclass of a more limited class
|
||||
|
||||
@@ -28,6 +28,7 @@ from fastapi._compat import (
|
||||
from fastapi.datastructures import DefaultPlaceholder, DefaultType
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Literal
|
||||
|
||||
if TYPE_CHECKING: # pragma: nocover
|
||||
from .routing import APIRoute
|
||||
@@ -68,6 +69,7 @@ def create_response_field(
|
||||
model_config: Type[BaseConfig] = BaseConfig,
|
||||
field_info: Optional[FieldInfo] = None,
|
||||
alias: Optional[str] = None,
|
||||
mode: Literal["validation", "serialization"] = "validation",
|
||||
) -> ModelField:
|
||||
"""
|
||||
Create a new response field. Raises if type_ is invalid.
|
||||
@@ -80,7 +82,9 @@ def create_response_field(
|
||||
else:
|
||||
field_info = field_info or FieldInfo()
|
||||
kwargs = {"name": name, "field_info": field_info}
|
||||
if not PYDANTIC_V2:
|
||||
if PYDANTIC_V2:
|
||||
kwargs.update({"mode": mode})
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"type_": type_,
|
||||
|
||||
Reference in New Issue
Block a user