mirror of
https://github.com/fastapi/fastapi.git
synced 2026-02-27 20:29:48 -05:00
✨ Implement support for JSON Lines and Streaming data with yield
This commit is contained in:
@@ -1,7 +1,17 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
@@ -251,6 +261,26 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
return get_typed_annotation(annotation, globalns)
|
||||
|
||||
|
||||
_STREAM_ORIGINS = {
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
AsyncGenerator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Generator,
|
||||
}
|
||||
|
||||
|
||||
def get_stream_item_type(annotation: Any) -> Any | None:
|
||||
origin = get_origin(annotation)
|
||||
if origin is not None and origin in _STREAM_ORIGINS:
|
||||
type_args = get_args(annotation)
|
||||
if type_args:
|
||||
return type_args[0]
|
||||
return Any
|
||||
return None
|
||||
|
||||
|
||||
def get_dependant(
|
||||
*,
|
||||
path: str,
|
||||
|
||||
@@ -355,25 +355,40 @@ def get_openapi_path(
|
||||
operation.setdefault("responses", {}).setdefault(status_code, {})[
|
||||
"description"
|
||||
] = route.response_description
|
||||
if route_response_media_type and is_body_allowed_for_status_code(
|
||||
route.status_code
|
||||
):
|
||||
response_schema = {"type": "string"}
|
||||
if lenient_issubclass(current_response_class, JSONResponse):
|
||||
if route.response_field:
|
||||
response_schema = get_schema_from_model_field(
|
||||
field=route.response_field,
|
||||
if is_body_allowed_for_status_code(route.status_code):
|
||||
# Check for JSONL streaming (generator endpoints)
|
||||
if route.is_json_stream:
|
||||
jsonl_content: dict[str, Any] = {}
|
||||
if route.stream_item_field:
|
||||
item_schema = get_schema_from_model_field(
|
||||
field=route.stream_item_field,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
jsonl_content["itemSchema"] = item_schema
|
||||
else:
|
||||
response_schema = {}
|
||||
operation.setdefault("responses", {}).setdefault(
|
||||
status_code, {}
|
||||
).setdefault("content", {}).setdefault(route_response_media_type, {})[
|
||||
"schema"
|
||||
] = response_schema
|
||||
jsonl_content["itemSchema"] = {}
|
||||
operation.setdefault("responses", {}).setdefault(
|
||||
status_code, {}
|
||||
).setdefault("content", {})["application/jsonl"] = jsonl_content
|
||||
elif route_response_media_type:
|
||||
response_schema = {"type": "string"}
|
||||
if lenient_issubclass(current_response_class, JSONResponse):
|
||||
if route.response_field:
|
||||
response_schema = get_schema_from_model_field(
|
||||
field=route.response_field,
|
||||
model_name_map=model_name_map,
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
else:
|
||||
response_schema = {}
|
||||
operation.setdefault("responses", {}).setdefault(
|
||||
status_code, {}
|
||||
).setdefault("content", {}).setdefault(
|
||||
route_response_media_type, {}
|
||||
)["schema"] = response_schema
|
||||
if route.responses:
|
||||
operation_responses = operation.setdefault("responses", {})
|
||||
for (
|
||||
@@ -453,9 +468,9 @@ def get_fields_from_routes(
|
||||
request_fields_from_routes: list[ModelField] = []
|
||||
callback_flat_models: list[ModelField] = []
|
||||
for route in routes:
|
||||
if getattr(route, "include_in_schema", None) and isinstance(
|
||||
route, routing.APIRoute
|
||||
):
|
||||
if not isinstance(route, routing.APIRoute):
|
||||
continue
|
||||
if route.include_in_schema:
|
||||
if route.body_field:
|
||||
assert isinstance(route.body_field, ModelField), (
|
||||
"A request body must be a Pydantic Field"
|
||||
@@ -465,6 +480,8 @@ def get_fields_from_routes(
|
||||
responses_from_routes.append(route.response_field)
|
||||
if route.response_fields:
|
||||
responses_from_routes.extend(route.response_fields.values())
|
||||
if route.stream_item_field:
|
||||
responses_from_routes.append(route.stream_item_field)
|
||||
if route.callbacks:
|
||||
callback_flat_models.extend(get_fields_from_routes(route.callbacks))
|
||||
params = get_flat_params(route.dependant)
|
||||
|
||||
@@ -11,6 +11,7 @@ from collections.abc import (
|
||||
Collection,
|
||||
Coroutine,
|
||||
Generator,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
@@ -42,6 +43,7 @@ from fastapi.dependencies.utils import (
|
||||
get_dependant,
|
||||
get_flat_dependant,
|
||||
get_parameterless_sub_dependant,
|
||||
get_stream_item_type,
|
||||
get_typed_return_annotation,
|
||||
solve_dependencies,
|
||||
)
|
||||
@@ -66,7 +68,7 @@ from starlette._utils import is_async_callable
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.routing import (
|
||||
BaseRoute,
|
||||
Match,
|
||||
@@ -315,6 +317,24 @@ async def run_endpoint_function(
|
||||
return await run_in_threadpool(dependant.call, **values)
|
||||
|
||||
|
||||
def _build_response_args(
|
||||
*, status_code: int | None, solved_result: Any
|
||||
) -> dict[str, Any]:
|
||||
response_args: dict[str, Any] = {
|
||||
"background": solved_result.background_tasks,
|
||||
}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code if status_code else solved_result.response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if solved_result.response.status_code:
|
||||
response_args["status_code"] = solved_result.response.status_code
|
||||
return response_args
|
||||
|
||||
|
||||
def get_request_handler(
|
||||
dependant: Dependant,
|
||||
body_field: ModelField | None = None,
|
||||
@@ -330,6 +350,8 @@ def get_request_handler(
|
||||
dependency_overrides_provider: Any | None = None,
|
||||
embed_body_fields: bool = False,
|
||||
strict_content_type: bool | DefaultPlaceholder = Default(True),
|
||||
stream_item_field: ModelField | None = None,
|
||||
is_json_stream: bool = False,
|
||||
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
assert dependant.call is not None, "dependant.call must be a function"
|
||||
is_coroutine = dependant.is_coroutine_callable
|
||||
@@ -427,61 +449,115 @@ def get_request_handler(
|
||||
embed_body_fields=embed_body_fields,
|
||||
)
|
||||
errors = solved_result.errors
|
||||
assert dependant.call # For types
|
||||
if not errors:
|
||||
raw_response = await run_endpoint_function(
|
||||
dependant=dependant,
|
||||
values=solved_result.values,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = solved_result.background_tasks
|
||||
response = raw_response
|
||||
else:
|
||||
response_args: dict[str, Any] = {
|
||||
"background": solved_result.background_tasks
|
||||
}
|
||||
# If status_code was set, use it, otherwise use the default from the
|
||||
# response class, in the case of redirect it's 307
|
||||
current_status_code = (
|
||||
status_code if status_code else solved_result.response.status_code
|
||||
)
|
||||
if current_status_code is not None:
|
||||
response_args["status_code"] = current_status_code
|
||||
if solved_result.response.status_code:
|
||||
response_args["status_code"] = solved_result.response.status_code
|
||||
# Use the fast path (dump_json) when no custom response
|
||||
# class was set and a response field with a TypeAdapter
|
||||
# exists. Serializes directly to JSON bytes via Pydantic's
|
||||
# Rust core, skipping the intermediate Python dict +
|
||||
# json.dumps() step.
|
||||
use_dump_json = response_field is not None and isinstance(
|
||||
response_class, DefaultPlaceholder
|
||||
)
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
include=response_model_include,
|
||||
exclude=response_model_exclude,
|
||||
by_alias=response_model_by_alias,
|
||||
exclude_unset=response_model_exclude_unset,
|
||||
exclude_defaults=response_model_exclude_defaults,
|
||||
exclude_none=response_model_exclude_none,
|
||||
is_coroutine=is_coroutine,
|
||||
endpoint_ctx=endpoint_ctx,
|
||||
dump_json=use_dump_json,
|
||||
)
|
||||
if use_dump_json:
|
||||
response = Response(
|
||||
content=content,
|
||||
media_type="application/json",
|
||||
**response_args,
|
||||
if is_json_stream:
|
||||
# Generator endpoint: stream as JSONL
|
||||
gen = dependant.call(**solved_result.values)
|
||||
|
||||
def _serialize_item(item: Any) -> bytes:
|
||||
if stream_item_field:
|
||||
value, errors = stream_item_field.validate(
|
||||
item, {}, loc=("response",)
|
||||
)
|
||||
if errors:
|
||||
ctx = endpoint_ctx or EndpointContext()
|
||||
raise ResponseValidationError(
|
||||
errors=errors,
|
||||
body=item,
|
||||
endpoint_ctx=ctx,
|
||||
)
|
||||
line = stream_item_field.serialize_json(
|
||||
value,
|
||||
include=response_model_include,
|
||||
exclude=response_model_exclude,
|
||||
by_alias=response_model_by_alias,
|
||||
exclude_unset=response_model_exclude_unset,
|
||||
exclude_defaults=response_model_exclude_defaults,
|
||||
exclude_none=response_model_exclude_none,
|
||||
)
|
||||
return line + b"\n"
|
||||
else:
|
||||
data = jsonable_encoder(item)
|
||||
return json.dumps(data).encode("utf-8") + b"\n"
|
||||
|
||||
if dependant.is_async_gen_callable:
|
||||
|
||||
async def _async_stream_jsonl() -> AsyncIterator[bytes]:
|
||||
async for item in gen:
|
||||
yield _serialize_item(item)
|
||||
|
||||
stream_content: AsyncIterator[bytes] | Iterator[bytes] = (
|
||||
_async_stream_jsonl()
|
||||
)
|
||||
else:
|
||||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
|
||||
def _sync_stream_jsonl() -> Iterator[bytes]:
|
||||
for item in gen:
|
||||
yield _serialize_item(item)
|
||||
|
||||
stream_content = _sync_stream_jsonl()
|
||||
|
||||
response = StreamingResponse(
|
||||
stream_content,
|
||||
media_type="application/jsonl",
|
||||
background=solved_result.background_tasks,
|
||||
)
|
||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||
elif dependant.is_async_gen_callable or dependant.is_gen_callable:
|
||||
# Raw streaming with explicit response_class (e.g. StreamingResponse)
|
||||
gen = dependant.call(**solved_result.values)
|
||||
response_args = _build_response_args(
|
||||
status_code=status_code, solved_result=solved_result
|
||||
)
|
||||
response = actual_response_class(content=gen, **response_args)
|
||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||
else:
|
||||
raw_response = await run_endpoint_function(
|
||||
dependant=dependant,
|
||||
values=solved_result.values,
|
||||
is_coroutine=is_coroutine,
|
||||
)
|
||||
if isinstance(raw_response, Response):
|
||||
if raw_response.background is None:
|
||||
raw_response.background = solved_result.background_tasks
|
||||
response = raw_response
|
||||
else:
|
||||
response_args = _build_response_args(
|
||||
status_code=status_code, solved_result=solved_result
|
||||
)
|
||||
# Use the fast path (dump_json) when no custom response
|
||||
# class was set and a response field with a TypeAdapter
|
||||
# exists. Serializes directly to JSON bytes via Pydantic's
|
||||
# Rust core, skipping the intermediate Python dict +
|
||||
# json.dumps() step.
|
||||
use_dump_json = response_field is not None and isinstance(
|
||||
response_class, DefaultPlaceholder
|
||||
)
|
||||
content = await serialize_response(
|
||||
field=response_field,
|
||||
response_content=raw_response,
|
||||
include=response_model_include,
|
||||
exclude=response_model_exclude,
|
||||
by_alias=response_model_by_alias,
|
||||
exclude_unset=response_model_exclude_unset,
|
||||
exclude_defaults=response_model_exclude_defaults,
|
||||
exclude_none=response_model_exclude_none,
|
||||
is_coroutine=is_coroutine,
|
||||
endpoint_ctx=endpoint_ctx,
|
||||
dump_json=use_dump_json,
|
||||
)
|
||||
if use_dump_json:
|
||||
response = Response(
|
||||
content=content,
|
||||
media_type="application/json",
|
||||
**response_args,
|
||||
)
|
||||
else:
|
||||
response = actual_response_class(content, **response_args)
|
||||
if not is_body_allowed_for_status_code(response.status_code):
|
||||
response.body = b""
|
||||
response.headers.raw.extend(solved_result.response.headers.raw)
|
||||
if errors:
|
||||
validation_error = RequestValidationError(
|
||||
errors, body=body, endpoint_ctx=endpoint_ctx
|
||||
@@ -609,12 +685,21 @@ class APIRoute(routing.Route):
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.stream_item_type: Any | None = None
|
||||
if isinstance(response_model, DefaultPlaceholder):
|
||||
return_annotation = get_typed_return_annotation(endpoint)
|
||||
if lenient_issubclass(return_annotation, Response):
|
||||
response_model = None
|
||||
else:
|
||||
response_model = return_annotation
|
||||
stream_item = get_stream_item_type(return_annotation)
|
||||
if stream_item is not None:
|
||||
# Only extract item type for JSONL streaming when no
|
||||
# explicit response_class (e.g. StreamingResponse) was set
|
||||
if isinstance(response_class, DefaultPlaceholder):
|
||||
self.stream_item_type = stream_item
|
||||
response_model = None
|
||||
else:
|
||||
response_model = return_annotation
|
||||
self.response_model = response_model
|
||||
self.summary = summary
|
||||
self.response_description = response_description
|
||||
@@ -663,6 +748,15 @@ class APIRoute(routing.Route):
|
||||
)
|
||||
else:
|
||||
self.response_field = None # type: ignore
|
||||
if self.stream_item_type:
|
||||
stream_item_name = "StreamItem_" + self.unique_id
|
||||
self.stream_item_field: ModelField | None = create_model_field(
|
||||
name=stream_item_name,
|
||||
type_=self.stream_item_type,
|
||||
mode="serialization",
|
||||
)
|
||||
else:
|
||||
self.stream_item_field = None
|
||||
self.dependencies = list(dependencies or [])
|
||||
self.description = description or inspect.cleandoc(self.endpoint.__doc__ or "")
|
||||
# if a "form feed" character (page break) is found in the description text,
|
||||
@@ -704,6 +798,11 @@ class APIRoute(routing.Route):
|
||||
name=self.unique_id,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
)
|
||||
# Detect generator endpoints that should stream as JSONL
|
||||
# (only when no explicit response_class like StreamingResponse is set)
|
||||
self.is_json_stream = isinstance(response_class, DefaultPlaceholder) and (
|
||||
self.dependant.is_async_gen_callable or self.dependant.is_gen_callable
|
||||
)
|
||||
self.app = request_response(self.get_route_handler())
|
||||
|
||||
def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]:
|
||||
@@ -722,6 +821,8 @@ class APIRoute(routing.Route):
|
||||
dependency_overrides_provider=self.dependency_overrides_provider,
|
||||
embed_body_fields=self._embed_body_fields,
|
||||
strict_content_type=self.strict_content_type,
|
||||
stream_item_field=self.stream_item_field,
|
||||
is_json_stream=self.is_json_stream,
|
||||
)
|
||||
|
||||
def matches(self, scope: Scope) -> tuple[Match, Scope]:
|
||||
|
||||
Reference in New Issue
Block a user