Compare commits

..

4 Commits

Author SHA1 Message Date
github-actions[bot]
63eb33ce10 🎨 Auto format 2026-06-15 13:33:02 +00:00
Yurii Motov
3b0c595c2e Fix response_model being set to None for non-generator endpoints 2026-06-15 15:31:57 +02:00
Yurii Motov
6ab384c423 Move blocks down 2026-06-15 15:30:11 +02:00
Yurii Motov
2fd2248d40 Add tests for response_model_* params with non-generator Iterable return type 2026-06-15 15:15:25 +02:00
2 changed files with 76 additions and 43 deletions

View File

@@ -930,28 +930,6 @@ def _populate_api_route_state(
route.path = path
route.endpoint = endpoint
route.stream_item_type = None
if isinstance(response_model, DefaultPlaceholder):
return_annotation = get_typed_return_annotation(endpoint)
if lenient_issubclass(return_annotation, Response):
response_model = None
else:
stream_item = get_stream_item_type(return_annotation)
if stream_item is not None:
# Extract item type for JSONL or SSE streaming when
# response_class is DefaultPlaceholder (JSONL) or
# EventSourceResponse (SSE).
# ServerSentEvent is excluded: it's a transport
# wrapper, not a data model, so it shouldn't feed
# into validation or OpenAPI schema generation.
if (
isinstance(response_class, DefaultPlaceholder)
or lenient_issubclass(response_class, EventSourceResponse)
) and not lenient_issubclass(stream_item, ServerSentEvent):
route.stream_item_type = stream_item
response_model = None
else:
response_model = return_annotation
route.response_model = response_model
route.summary = summary
route.response_description = response_description
route.deprecated = deprecated
@@ -987,27 +965,6 @@ def _populate_api_route_state(
if isinstance(status_code, IntEnum):
status_code = int(status_code)
route.status_code = status_code
if route.response_model:
assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body"
)
response_name = "Response_" + route.unique_id
route.response_field = create_model_field(
name=response_name,
type_=route.response_model,
mode="serialization",
)
else:
route.response_field = None
if route.stream_item_type:
stream_item_name = "StreamItem_" + route.unique_id
route.stream_item_field = create_model_field(
name=stream_item_name,
type_=route.stream_item_type,
mode="serialization",
)
else:
route.stream_item_field = None
route.dependencies = list(dependencies or [])
route.description = description or inspect.cleandoc(route.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
@@ -1059,6 +1016,50 @@ def _populate_api_route_state(
route.is_json_stream = is_generator and isinstance(
response_class, DefaultPlaceholder
)
if isinstance(response_model, DefaultPlaceholder):
return_annotation = get_typed_return_annotation(endpoint)
if lenient_issubclass(return_annotation, Response):
response_model = None
else:
stream_item = get_stream_item_type(return_annotation)
if stream_item is not None and is_generator:
# Extract item type for JSONL or SSE streaming for
# generator endpoints when response_class is
# DefaultPlaceholder (JSONL) or EventSourceResponse
# (SSE).
# ServerSentEvent is excluded: it's a transport
# wrapper, not a data model, so it shouldn't feed
# into validation or OpenAPI schema generation.
if (
isinstance(response_class, DefaultPlaceholder)
or lenient_issubclass(response_class, EventSourceResponse)
) and not lenient_issubclass(stream_item, ServerSentEvent):
route.stream_item_type = stream_item
response_model = None
else:
response_model = return_annotation
route.response_model = response_model
if route.response_model:
assert is_body_allowed_for_status_code(status_code), (
f"Status code {status_code} must not have a response body"
)
response_name = "Response_" + route.unique_id
route.response_field = create_model_field(
name=response_name,
type_=route.response_model,
mode="serialization",
)
else:
route.response_field = None
if route.stream_item_type:
stream_item_name = "StreamItem_" + route.unique_id
route.stream_item_field = create_model_field(
name=stream_item_name,
type_=route.stream_item_type,
mode="serialization",
)
else:
route.stream_item_field = None
class APIRoute(routing.Route):

View File

@@ -1,3 +1,5 @@
from collections.abc import Iterable
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel
@@ -65,6 +67,21 @@ def get_exclude_unset_none() -> ModelDefaults:
return ModelDefaults(x=None, y="y")
@app.get("/iterable_exclude_unset", response_model_exclude_unset=True)
def get_iterable_exclude_unset() -> Iterable[ModelDefaults]:
return [ModelDefaults(x=None, y="y")]
@app.get("/iterable_exclude_defaults", response_model_exclude_defaults=True)
def get_iterable_exclude_defaults() -> Iterable[ModelDefaults]:
return [ModelDefaults(x=None, y="y")]
@app.get("/iterable_exclude_none", response_model_exclude_none=True)
def get_iterable_exclude_none() -> Iterable[ModelDefaults]:
return [ModelDefaults(x=None, y="y")]
client = TestClient(app)
@@ -91,3 +108,18 @@ def test_return_exclude_none():
def test_return_exclude_unset_none():
response = client.get("/exclude_unset_none")
assert response.json() == {"y": "y"}
def test_return_iterable_exclude_unset():
response = client.get("/iterable_exclude_unset")
assert response.json() == [{"x": None, "y": "y"}]
def test_return_iterable_exclude_defaults():
response = client.get("/iterable_exclude_defaults")
assert response.json() == [{}]
def test_return_iterable_exclude_none():
response = client.get("/iterable_exclude_none")
assert response.json() == [{"y": "y", "z": "z"}]