Implement support for JSON Lines and Streaming data with yield

This commit is contained in:
Sebastián Ramírez
2026-02-27 01:27:47 +01:00
parent e6ddf0c122
commit 07ab822c7c
3 changed files with 219 additions and 71 deletions

View File

@@ -1,7 +1,17 @@
import dataclasses
import inspect
import sys
from collections.abc import Callable, Mapping, Sequence
from collections.abc import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Callable,
Generator,
Iterable,
Iterator,
Mapping,
Sequence,
)
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
@@ -251,6 +261,26 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
_STREAM_ORIGINS = {
AsyncIterable,
AsyncIterator,
AsyncGenerator,
Iterable,
Iterator,
Generator,
}
def get_stream_item_type(annotation: Any) -> Any | None:
origin = get_origin(annotation)
if origin is not None and origin in _STREAM_ORIGINS:
type_args = get_args(annotation)
if type_args:
return type_args[0]
return Any
return None
def get_dependant(
*,
path: str,

View File

@@ -355,25 +355,40 @@ def get_openapi_path(
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = route.response_description
if route_response_media_type and is_body_allowed_for_status_code(
route.status_code
):
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
if is_body_allowed_for_status_code(route.status_code):
# Check for JSONL streaming (generator endpoints)
if route.is_json_stream:
jsonl_content: dict[str, Any] = {}
if route.stream_item_field:
item_schema = get_schema_from_model_field(
field=route.stream_item_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
jsonl_content["itemSchema"] = item_schema
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(route_response_media_type, {})[
"schema"
] = response_schema
jsonl_content["itemSchema"] = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {})["application/jsonl"] = jsonl_content
elif route_response_media_type:
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema = get_schema_from_model_field(
field=route.response_field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(
status_code, {}
).setdefault("content", {}).setdefault(
route_response_media_type, {}
)["schema"] = response_schema
if route.responses:
operation_responses = operation.setdefault("responses", {})
for (
@@ -453,9 +468,9 @@ def get_fields_from_routes(
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if not isinstance(route, routing.APIRoute):
continue
if route.include_in_schema:
if route.body_field:
assert isinstance(route.body_field, ModelField), (
"A request body must be a Pydantic Field"
@@ -465,6 +480,8 @@ def get_fields_from_routes(
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.stream_item_field:
responses_from_routes.append(route.stream_item_field)
if route.callbacks:
callback_flat_models.extend(get_fields_from_routes(route.callbacks))
params = get_flat_params(route.dependant)

View File

@@ -11,6 +11,7 @@ from collections.abc import (
Collection,
Coroutine,
Generator,
Iterator,
Mapping,
Sequence,
)
@@ -42,6 +43,7 @@ from fastapi.dependencies.utils import (
get_dependant,
get_flat_dependant,
get_parameterless_sub_dependant,
get_stream_item_type,
get_typed_return_annotation,
solve_dependencies,
)
@@ -66,7 +68,7 @@ from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import (
BaseRoute,
Match,
@@ -315,6 +317,24 @@ async def run_endpoint_function(
return await run_in_threadpool(dependant.call, **values)
def _build_response_args(
*, status_code: int | None, solved_result: Any
) -> dict[str, Any]:
response_args: dict[str, Any] = {
"background": solved_result.background_tasks,
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if solved_result.response.status_code:
response_args["status_code"] = solved_result.response.status_code
return response_args
def get_request_handler(
dependant: Dependant,
body_field: ModelField | None = None,
@@ -330,6 +350,8 @@ def get_request_handler(
dependency_overrides_provider: Any | None = None,
embed_body_fields: bool = False,
strict_content_type: bool | DefaultPlaceholder = Default(True),
stream_item_field: ModelField | None = None,
is_json_stream: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = dependant.is_coroutine_callable
@@ -427,61 +449,115 @@ def get_request_handler(
embed_body_fields=embed_body_fields,
)
errors = solved_result.errors
assert dependant.call # For types
if not errors:
raw_response = await run_endpoint_function(
dependant=dependant,
values=solved_result.values,
is_coroutine=is_coroutine,
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if solved_result.response.status_code:
response_args["status_code"] = solved_result.response.status_code
# Use the fast path (dump_json) when no custom response
# class was set and a response field with a TypeAdapter
# exists. Serializes directly to JSON bytes via Pydantic's
# Rust core, skipping the intermediate Python dict +
# json.dumps() step.
use_dump_json = response_field is not None and isinstance(
response_class, DefaultPlaceholder
)
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
endpoint_ctx=endpoint_ctx,
dump_json=use_dump_json,
)
if use_dump_json:
response = Response(
content=content,
media_type="application/json",
**response_args,
if is_json_stream:
# Generator endpoint: stream as JSONL
gen = dependant.call(**solved_result.values)
def _serialize_item(item: Any) -> bytes:
if stream_item_field:
value, errors = stream_item_field.validate(
item, {}, loc=("response",)
)
if errors:
ctx = endpoint_ctx or EndpointContext()
raise ResponseValidationError(
errors=errors,
body=item,
endpoint_ctx=ctx,
)
line = stream_item_field.serialize_json(
value,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
)
return line + b"\n"
else:
data = jsonable_encoder(item)
return json.dumps(data).encode("utf-8") + b"\n"
if dependant.is_async_gen_callable:
async def _async_stream_jsonl() -> AsyncIterator[bytes]:
async for item in gen:
yield _serialize_item(item)
stream_content: AsyncIterator[bytes] | Iterator[bytes] = (
_async_stream_jsonl()
)
else:
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
def _sync_stream_jsonl() -> Iterator[bytes]:
for item in gen:
yield _serialize_item(item)
stream_content = _sync_stream_jsonl()
response = StreamingResponse(
stream_content,
media_type="application/jsonl",
background=solved_result.background_tasks,
)
response.headers.raw.extend(solved_result.response.headers.raw)
elif dependant.is_async_gen_callable or dependant.is_gen_callable:
# Raw streaming with explicit response_class (e.g. StreamingResponse)
gen = dependant.call(**solved_result.values)
response_args = _build_response_args(
status_code=status_code, solved_result=solved_result
)
response = actual_response_class(content=gen, **response_args)
response.headers.raw.extend(solved_result.response.headers.raw)
else:
raw_response = await run_endpoint_function(
dependant=dependant,
values=solved_result.values,
is_coroutine=is_coroutine,
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args = _build_response_args(
status_code=status_code, solved_result=solved_result
)
# Use the fast path (dump_json) when no custom response
# class was set and a response field with a TypeAdapter
# exists. Serializes directly to JSON bytes via Pydantic's
# Rust core, skipping the intermediate Python dict +
# json.dumps() step.
use_dump_json = response_field is not None and isinstance(
response_class, DefaultPlaceholder
)
content = await serialize_response(
field=response_field,
response_content=raw_response,
include=response_model_include,
exclude=response_model_exclude,
by_alias=response_model_by_alias,
exclude_unset=response_model_exclude_unset,
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
endpoint_ctx=endpoint_ctx,
dump_json=use_dump_json,
)
if use_dump_json:
response = Response(
content=content,
media_type="application/json",
**response_args,
)
else:
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
errors, body=body, endpoint_ctx=endpoint_ctx
@@ -609,12 +685,21 @@ class APIRoute(routing.Route):
) -> None:
self.path = path
self.endpoint = endpoint
self.stream_item_type: Any | None = None
if isinstance(response_model, DefaultPlaceholder):
return_annotation = get_typed_return_annotation(endpoint)
if lenient_issubclass(return_annotation, Response):
response_model = None
else:
response_model = return_annotation
stream_item = get_stream_item_type(return_annotation)
if stream_item is not None:
# Only extract item type for JSONL streaming when no
# explicit response_class (e.g. StreamingResponse) was set
if isinstance(response_class, DefaultPlaceholder):
self.stream_item_type = stream_item
response_model = None
else:
response_model = return_annotation
self.response_model = response_model
self.summary = summary
self.response_description = response_description
@@ -663,6 +748,15 @@ class APIRoute(routing.Route):
)
else:
self.response_field = None # type: ignore
if self.stream_item_type:
stream_item_name = "StreamItem_" + self.unique_id
self.stream_item_field: ModelField | None = create_model_field(
name=stream_item_name,
type_=self.stream_item_type,
mode="serialization",
)
else:
self.stream_item_field = None
self.dependencies = list(dependencies or [])
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
# if a "form feed" character (page break) is found in the description text,
@@ -704,6 +798,11 @@ class APIRoute(routing.Route):
name=self.unique_id,
embed_body_fields=self._embed_body_fields,
)
# Detect generator endpoints that should stream as JSONL
# (only when no explicit response_class like StreamingResponse is set)
self.is_json_stream = isinstance(response_class, DefaultPlaceholder) and (
self.dependant.is_async_gen_callable or self.dependant.is_gen_callable
)
self.app = request_response(self.get_route_handler())
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
@@ -722,6 +821,8 @@ class APIRoute(routing.Route):
dependency_overrides_provider=self.dependency_overrides_provider,
embed_body_fields=self._embed_body_fields,
strict_content_type=self.strict_content_type,
stream_item_field=self.stream_item_field,
is_json_stream=self.is_json_stream,
)
def matches(self, scope: Scope) -> tuple[Match, Scope]: