diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index ab18ec2db6..8fcf1a5b3c 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -1,7 +1,17 @@ import dataclasses import inspect import sys -from collections.abc import Callable, Mapping, Sequence +from collections.abc import ( + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Callable, + Generator, + Iterable, + Iterator, + Mapping, + Sequence, +) from contextlib import AsyncExitStack, contextmanager from copy import copy, deepcopy from dataclasses import dataclass @@ -251,6 +261,26 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any: return get_typed_annotation(annotation, globalns) +_STREAM_ORIGINS = { + AsyncIterable, + AsyncIterator, + AsyncGenerator, + Iterable, + Iterator, + Generator, +} + + +def get_stream_item_type(annotation: Any) -> Any | None: + origin = get_origin(annotation) + if origin is not None and origin in _STREAM_ORIGINS: + type_args = get_args(annotation) + if type_args: + return type_args[0] + return Any + return None + + def get_dependant( *, path: str, diff --git a/fastapi/openapi/utils.py b/fastapi/openapi/utils.py index 812003aee3..3ddc0c14a9 100644 --- a/fastapi/openapi/utils.py +++ b/fastapi/openapi/utils.py @@ -355,25 +355,40 @@ def get_openapi_path( operation.setdefault("responses", {}).setdefault(status_code, {})[ "description" ] = route.response_description - if route_response_media_type and is_body_allowed_for_status_code( - route.status_code - ): - response_schema = {"type": "string"} - if lenient_issubclass(current_response_class, JSONResponse): - if route.response_field: - response_schema = get_schema_from_model_field( - field=route.response_field, + if is_body_allowed_for_status_code(route.status_code): + # Check for JSONL streaming (generator endpoints) + if route.is_json_stream: + jsonl_content: dict[str, Any] = {} + if route.stream_item_field: + item_schema = get_schema_from_model_field( + field=route.stream_item_field, model_name_map=model_name_map, field_mapping=field_mapping, separate_input_output_schemas=separate_input_output_schemas, ) + jsonl_content["itemSchema"] = item_schema else: - response_schema = {} - operation.setdefault("responses", {}).setdefault( - status_code, {} - ).setdefault("content", {}).setdefault(route_response_media_type, {})[ - "schema" - ] = response_schema + jsonl_content["itemSchema"] = {} + operation.setdefault("responses", {}).setdefault( + status_code, {} + ).setdefault("content", {})["application/jsonl"] = jsonl_content + elif route_response_media_type: + response_schema = {"type": "string"} + if lenient_issubclass(current_response_class, JSONResponse): + if route.response_field: + response_schema = get_schema_from_model_field( + field=route.response_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + separate_input_output_schemas=separate_input_output_schemas, + ) + else: + response_schema = {} + operation.setdefault("responses", {}).setdefault( + status_code, {} + ).setdefault("content", {}).setdefault( + route_response_media_type, {} + )["schema"] = response_schema if route.responses: operation_responses = operation.setdefault("responses", {}) for ( @@ -453,9 +468,9 @@ def get_fields_from_routes( request_fields_from_routes: list[ModelField] = [] callback_flat_models: list[ModelField] = [] for route in routes: - if getattr(route, "include_in_schema", None) and isinstance( - route, routing.APIRoute - ): + if not isinstance(route, routing.APIRoute): + continue + if route.include_in_schema: if route.body_field: assert isinstance(route.body_field, ModelField), ( "A request body must be a Pydantic Field" @@ -465,6 +480,8 @@ def get_fields_from_routes( responses_from_routes.append(route.response_field) if route.response_fields: responses_from_routes.extend(route.response_fields.values()) + if route.stream_item_field: + responses_from_routes.append(route.stream_item_field) if route.callbacks: callback_flat_models.extend(get_fields_from_routes(route.callbacks)) params = get_flat_params(route.dependant) diff --git a/fastapi/routing.py b/fastapi/routing.py index d17650a627..aec4b5c3d1 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -11,6 +11,7 @@ from collections.abc import ( Collection, Coroutine, Generator, + Iterator, Mapping, Sequence, ) @@ -42,6 +43,7 @@ from fastapi.dependencies.utils import ( get_dependant, get_flat_dependant, get_parameterless_sub_dependant, + get_stream_item_type, get_typed_return_annotation, solve_dependencies, ) @@ -66,7 +68,7 @@ from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import JSONResponse, Response, StreamingResponse from starlette.routing import ( BaseRoute, Match, @@ -315,6 +317,24 @@ async def run_endpoint_function( return await run_in_threadpool(dependant.call, **values) +def _build_response_args( + *, status_code: int | None, solved_result: Any +) -> dict[str, Any]: + response_args: dict[str, Any] = { + "background": solved_result.background_tasks, + } + # If status_code was set, use it, otherwise use the default from the + # response class, in the case of redirect it's 307 + current_status_code = ( + status_code if status_code else solved_result.response.status_code + ) + if current_status_code is not None: + response_args["status_code"] = current_status_code + if solved_result.response.status_code: + response_args["status_code"] = solved_result.response.status_code + return response_args + + def get_request_handler( dependant: Dependant, body_field: ModelField | None = None, @@ -330,6 +350,8 @@ def get_request_handler( dependency_overrides_provider: Any | None = None, embed_body_fields: bool = False, strict_content_type: bool | DefaultPlaceholder = Default(True), + stream_item_field: ModelField | None = None, + is_json_stream: bool = False, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" is_coroutine = dependant.is_coroutine_callable @@ -427,61 +449,115 @@ def get_request_handler( embed_body_fields=embed_body_fields, ) errors = solved_result.errors + assert dependant.call # For types if not errors: - raw_response = await run_endpoint_function( - dependant=dependant, - values=solved_result.values, - is_coroutine=is_coroutine, - ) - if isinstance(raw_response, Response): - if raw_response.background is None: - raw_response.background = solved_result.background_tasks - response = raw_response - else: - response_args: dict[str, Any] = { - "background": solved_result.background_tasks - } - # If status_code was set, use it, otherwise use the default from the - # response class, in the case of redirect it's 307 - current_status_code = ( - status_code if status_code else solved_result.response.status_code - ) - if current_status_code is not None: - response_args["status_code"] = current_status_code - if solved_result.response.status_code: - response_args["status_code"] = solved_result.response.status_code - # Use the fast path (dump_json) when no custom response - # class was set and a response field with a TypeAdapter - # exists. Serializes directly to JSON bytes via Pydantic's - # Rust core, skipping the intermediate Python dict + - # json.dumps() step. - use_dump_json = response_field is not None and isinstance( - response_class, DefaultPlaceholder - ) - content = await serialize_response( - field=response_field, - response_content=raw_response, - include=response_model_include, - exclude=response_model_exclude, - by_alias=response_model_by_alias, - exclude_unset=response_model_exclude_unset, - exclude_defaults=response_model_exclude_defaults, - exclude_none=response_model_exclude_none, - is_coroutine=is_coroutine, - endpoint_ctx=endpoint_ctx, - dump_json=use_dump_json, - ) - if use_dump_json: - response = Response( - content=content, - media_type="application/json", - **response_args, + if is_json_stream: + # Generator endpoint: stream as JSONL + gen = dependant.call(**solved_result.values) + + def _serialize_item(item: Any) -> bytes: + if stream_item_field: + value, errors = stream_item_field.validate( + item, {}, loc=("response",) + ) + if errors: + ctx = endpoint_ctx or EndpointContext() + raise ResponseValidationError( + errors=errors, + body=item, + endpoint_ctx=ctx, + ) + line = stream_item_field.serialize_json( + value, + include=response_model_include, + exclude=response_model_exclude, + by_alias=response_model_by_alias, + exclude_unset=response_model_exclude_unset, + exclude_defaults=response_model_exclude_defaults, + exclude_none=response_model_exclude_none, + ) + return line + b"\n" + else: + data = jsonable_encoder(item) + return json.dumps(data).encode("utf-8") + b"\n" + + if dependant.is_async_gen_callable: + + async def _async_stream_jsonl() -> AsyncIterator[bytes]: + async for item in gen: + yield _serialize_item(item) + + stream_content: AsyncIterator[bytes] | Iterator[bytes] = ( + _async_stream_jsonl() ) else: - response = actual_response_class(content, **response_args) - if not is_body_allowed_for_status_code(response.status_code): - response.body = b"" + + def _sync_stream_jsonl() -> Iterator[bytes]: + for item in gen: + yield _serialize_item(item) + + stream_content = _sync_stream_jsonl() + + response = StreamingResponse( + stream_content, + media_type="application/jsonl", + background=solved_result.background_tasks, + ) response.headers.raw.extend(solved_result.response.headers.raw) + elif dependant.is_async_gen_callable or dependant.is_gen_callable: + # Raw streaming with explicit response_class (e.g. StreamingResponse) + gen = dependant.call(**solved_result.values) + response_args = _build_response_args( + status_code=status_code, solved_result=solved_result + ) + response = actual_response_class(content=gen, **response_args) + response.headers.raw.extend(solved_result.response.headers.raw) + else: + raw_response = await run_endpoint_function( + dependant=dependant, + values=solved_result.values, + is_coroutine=is_coroutine, + ) + if isinstance(raw_response, Response): + if raw_response.background is None: + raw_response.background = solved_result.background_tasks + response = raw_response + else: + response_args = _build_response_args( + status_code=status_code, solved_result=solved_result + ) + # Use the fast path (dump_json) when no custom response + # class was set and a response field with a TypeAdapter + # exists. Serializes directly to JSON bytes via Pydantic's + # Rust core, skipping the intermediate Python dict + + # json.dumps() step. + use_dump_json = response_field is not None and isinstance( + response_class, DefaultPlaceholder + ) + content = await serialize_response( + field=response_field, + response_content=raw_response, + include=response_model_include, + exclude=response_model_exclude, + by_alias=response_model_by_alias, + exclude_unset=response_model_exclude_unset, + exclude_defaults=response_model_exclude_defaults, + exclude_none=response_model_exclude_none, + is_coroutine=is_coroutine, + endpoint_ctx=endpoint_ctx, + dump_json=use_dump_json, + ) + if use_dump_json: + response = Response( + content=content, + media_type="application/json", + **response_args, + ) + else: + response = actual_response_class(content, **response_args) + if not is_body_allowed_for_status_code(response.status_code): + response.body = b"" + response.headers.raw.extend(solved_result.response.headers.raw) if errors: validation_error = RequestValidationError( errors, body=body, endpoint_ctx=endpoint_ctx @@ -609,12 +685,21 @@ class APIRoute(routing.Route): ) -> None: self.path = path self.endpoint = endpoint + self.stream_item_type: Any | None = None if isinstance(response_model, DefaultPlaceholder): return_annotation = get_typed_return_annotation(endpoint) if lenient_issubclass(return_annotation, Response): response_model = None else: - response_model = return_annotation + stream_item = get_stream_item_type(return_annotation) + if stream_item is not None: + # Only extract item type for JSONL streaming when no + # explicit response_class (e.g. StreamingResponse) was set + if isinstance(response_class, DefaultPlaceholder): + self.stream_item_type = stream_item + response_model = None + else: + response_model = return_annotation self.response_model = response_model self.summary = summary self.response_description = response_description @@ -663,6 +748,15 @@ class APIRoute(routing.Route): ) else: self.response_field = None # type: ignore + if self.stream_item_type: + stream_item_name = "StreamItem_" + self.unique_id + self.stream_item_field: ModelField | None = create_model_field( + name=stream_item_name, + type_=self.stream_item_type, + mode="serialization", + ) + else: + self.stream_item_field = None self.dependencies = list(dependencies or []) self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "") # if a "form feed" character (page break) is found in the description text, @@ -704,6 +798,11 @@ class APIRoute(routing.Route): name=self.unique_id, embed_body_fields=self._embed_body_fields, ) + # Detect generator endpoints that should stream as JSONL + # (only when no explicit response_class like StreamingResponse is set) + self.is_json_stream = isinstance(response_class, DefaultPlaceholder) and ( + self.dependant.is_async_gen_callable or self.dependant.is_gen_callable + ) self.app = request_response(self.get_route_handler()) def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: @@ -722,6 +821,8 @@ class APIRoute(routing.Route): dependency_overrides_provider=self.dependency_overrides_provider, embed_body_fields=self._embed_body_fields, strict_content_type=self.strict_content_type, + stream_item_field=self.stream_item_field, + is_json_stream=self.is_json_stream, ) def matches(self, scope: Scope) -> tuple[Match, Scope]: