Extract resolve_nested_model helper to reduce cast() duplication

Consolidates the repeated pattern of resolving a FieldInfo's default
to a Pydantic BaseModel instance (via default or default_factory) into
a single helper. Also tightens update_description_with_new_default
parameter types from Any to object.
This commit is contained in:
Sina Atalay
2026-03-24 19:20:50 +03:00
parent 0857fc1d1b
commit 2552e18c96

View File

@@ -29,6 +29,31 @@ def sanitize_defaults(value: Any) -> Any:
return value
def resolve_nested_model(field_info: FieldInfo) -> pydantic.BaseModel | None:
"""Resolve a FieldInfo's default to a Pydantic model instance.
Why:
Multiple functions need to inspect whether a field's default is a nested
Pydantic model. The default can come from either a direct value or a
factory callable. This helper encapsulates the resolution and isinstance
check, eliminating repeated cast() calls across the module.
Args:
field_info: Pydantic field info to inspect.
Returns:
Model instance if the default is a BaseModel, otherwise None.
"""
if field_info.default_factory is not None:
factory = cast(Callable[[], Any], field_info.default_factory)
obj = factory()
if isinstance(obj, pydantic.BaseModel):
return obj
elif isinstance(field_info.default, pydantic.BaseModel):
return field_info.default
return None
def create_variant_pydantic_model[T: pydantic.BaseModel](
*,
variant_name: str,
@@ -155,15 +180,7 @@ def validate_defaults_against_base(
continue
base_field_info = base_fields[field_name]
nested_obj: pydantic.BaseModel | None = None
if base_field_info.default_factory is not None:
factory = cast(Callable[[], Any], base_field_info.default_factory)
obj = factory()
if isinstance(obj, pydantic.BaseModel):
nested_obj = obj
elif isinstance(base_field_info.default, pydantic.BaseModel):
nested_obj = base_field_info.default
nested_obj = resolve_nested_model(base_field_info)
if nested_obj is not None:
nested_fields = type(nested_obj).model_fields
@@ -195,8 +212,8 @@ def generate_model_name(variant_name: str, class_name_suffix: str) -> str:
def update_description_with_new_default(
original_description: str | None,
old_default: Any,
new_default: Any,
old_default: object,
new_default: object,
) -> str | None:
"""Update field description to reflect new default value.
@@ -331,14 +348,9 @@ def create_nested_model_variant_model(
if isinstance(new_value, dict):
# Check if this field is a nested Pydantic model
nested_obj = None
if base_field_info.default_factory is not None:
factory = cast(Callable[[], Any], base_field_info.default_factory)
nested_obj = factory()
elif isinstance(base_field_info.default, pydantic.BaseModel):
nested_obj = base_field_info.default
nested_obj = resolve_nested_model(base_field_info)
if nested_obj is not None and isinstance(nested_obj, pydantic.BaseModel):
if nested_obj is not None:
# Recursively create nested field spec
field_specs[field_name] = create_nested_field_spec(
new_value, base_field_info
@@ -381,17 +393,7 @@ def create_nested_field_spec(
Returns:
Tuple of variant class annotation and Field with default_factory.
"""
# Get the base nested object - could be from default or default_factory
base_nested_obj: pydantic.BaseModel | None = None
if base_field_info.default_factory is not None:
# Create an instance using the factory
# Cast to proper callable type to satisfy type checker
factory = cast(Callable[[], Any], base_field_info.default_factory)
base_nested_obj = cast(pydantic.BaseModel, factory())
elif isinstance(base_field_info.default, pydantic.BaseModel):
# The default is already a Pydantic model instance
base_nested_obj = base_field_info.default
base_nested_obj = resolve_nested_model(base_field_info)
if base_nested_obj is not None:
# Create a variant class with updated field specs and descriptions