mirror of
https://github.com/fastapi/fastapi.git
synced 2025-12-26 15:51:02 -05:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67f8cb3b4f | ||
|
|
5c2828bd13 | ||
|
|
b087246f26 | ||
|
|
219d299426 | ||
|
|
31da760729 |
2
Pipfile
2
Pipfile
@@ -26,7 +26,7 @@ uvicorn = "*"
|
||||
|
||||
[packages]
|
||||
starlette = "==0.12.0"
|
||||
pydantic = "==0.25.0"
|
||||
pydantic = "==0.26.0"
|
||||
databases = {extras = ["sqlite"],version = "*"}
|
||||
hypercorn = "*"
|
||||
|
||||
|
||||
8
Pipfile.lock
generated
8
Pipfile.lock
generated
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "8d23c38a8d3018315f49a2298e9098d8b5f248338dbba9e024246c9abb5949a2"
|
||||
"sha256": "4a33b47e814fa75533548874ffadbc6163b3058db4d1615ff633512366d72ccb"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
@@ -113,11 +113,11 @@
|
||||
},
|
||||
"pydantic": {
|
||||
"hashes": [
|
||||
"sha256:2203e01c1d87a3d964aa0db56efdb1b89a90eca610ab3f0ddea396e2a5fa4cc4",
|
||||
"sha256:ac207906e78b1cafbbff6d57b0ce51b989cf5361d2487013f0b353f3bb3b8442"
|
||||
"sha256:b72e0df2463cee746cf42639845d4106c19f30136375e779352d710e69617731",
|
||||
"sha256:dab99d3070e040b8b2e987dfbe237350ab92d5d57a22d4e0e268ede2d85c7964"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==0.25.0"
|
||||
"version": "==0.26.0"
|
||||
},
|
||||
"pytoml": {
|
||||
"hashes": [
|
||||
|
||||
@@ -1,5 +1,23 @@
|
||||
## Next release
|
||||
|
||||
## 0.24.0
|
||||
|
||||
* Add support for WebSockets with dependencies and parameters.
|
||||
* Support included for:
|
||||
* `Depends`
|
||||
* `Security`
|
||||
* `Cookie`
|
||||
* `Header`
|
||||
* `Path`
|
||||
* `Query`
|
||||
* ...as these are compatible with the WebSockets protocol (e.g. `Body` is not).
|
||||
* [Updated documentation for WebSockets](https://fastapi.tiangolo.com/tutorial/websockets/).
|
||||
* PR [#178](https://github.com/tiangolo/fastapi/pull/178) by [@jekirl](https://github.com/jekirl).
|
||||
|
||||
* Upgrade the compatible version of Pydantic to `0.26.0`.
|
||||
* This includes JSON Schema support for IP address and network objects, bug fixes, and other features.
|
||||
* PR [#247](https://github.com/tiangolo/fastapi/pull/247) by [@euri10](https://github.com/euri10).
|
||||
|
||||
## 0.23.0
|
||||
|
||||
* Upgrade the compatible version of Starlette to `0.12.0`.
|
||||
|
||||
0
docs/src/websockets/__init__.py
Normal file
0
docs/src/websockets/__init__.py
Normal file
@@ -44,10 +44,9 @@ async def get():
|
||||
return HTMLResponse(html)
|
||||
|
||||
|
||||
@app.websocket_route("/ws")
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await websocket.send_text(f"Message text was: {data}")
|
||||
await websocket.close()
|
||||
|
||||
78
docs/src/websockets/tutorial002.py
Normal file
78
docs/src/websockets/tutorial002.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from fastapi import Cookie, Depends, FastAPI, Header
|
||||
from starlette.responses import HTMLResponse
|
||||
from starlette.status import WS_1008_POLICY_VIOLATION
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
html = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Chat</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>WebSocket Chat</h1>
|
||||
<form action="" onsubmit="sendMessage(event)">
|
||||
<label>Item ID: <input type="text" id="itemId" autocomplete="off" value="foo"/></label>
|
||||
<button onclick="connect(event)">Connect</button>
|
||||
<br>
|
||||
<label>Message: <input type="text" id="messageText" autocomplete="off"/></label>
|
||||
<button>Send</button>
|
||||
</form>
|
||||
<ul id='messages'>
|
||||
</ul>
|
||||
<script>
|
||||
var ws = null;
|
||||
function connect(event) {
|
||||
var input = document.getElementById("itemId")
|
||||
ws = new WebSocket("ws://localhost:8000/items/" + input.value + "/ws");
|
||||
ws.onmessage = function(event) {
|
||||
var messages = document.getElementById('messages')
|
||||
var message = document.createElement('li')
|
||||
var content = document.createTextNode(event.data)
|
||||
message.appendChild(content)
|
||||
messages.appendChild(message)
|
||||
};
|
||||
}
|
||||
function sendMessage(event) {
|
||||
var input = document.getElementById("messageText")
|
||||
ws.send(input.value)
|
||||
input.value = ''
|
||||
event.preventDefault()
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(html)
|
||||
|
||||
|
||||
async def get_cookie_or_client(
|
||||
websocket: WebSocket, session: str = Cookie(None), x_client: str = Header(None)
|
||||
):
|
||||
if session is None and x_client is None:
|
||||
await websocket.close(code=WS_1008_POLICY_VIOLATION)
|
||||
return session or x_client
|
||||
|
||||
|
||||
@app.websocket("/items/{item_id}/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
item_id: int,
|
||||
q: str = None,
|
||||
cookie_or_client: str = Depends(get_cookie_or_client),
|
||||
):
|
||||
await websocket.accept()
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
await websocket.send_text(
|
||||
f"Session Cookie or X-Client Header value is: {cookie_or_client}"
|
||||
)
|
||||
if q is not None:
|
||||
await websocket.send_text(f"Query parameter q is: {q}")
|
||||
await websocket.send_text(f"Message text was: {data}, for item ID: {item_id}")
|
||||
@@ -27,9 +27,9 @@ But it's the simplest way to focus on the server-side of WebSockets and have a w
|
||||
{!./src/websockets/tutorial001.py!}
|
||||
```
|
||||
|
||||
## Create a `websocket_route`
|
||||
## Create a `websocket`
|
||||
|
||||
In your **FastAPI** application, create a `websocket_route`:
|
||||
In your **FastAPI** application, create a `websocket`:
|
||||
|
||||
```Python hl_lines="3 47 48"
|
||||
{!./src/websockets/tutorial001.py!}
|
||||
@@ -38,15 +38,6 @@ In your **FastAPI** application, create a `websocket_route`:
|
||||
!!! tip
|
||||
In this example we are importing `WebSocket` from `starlette.websockets` to use it in the type declaration in the WebSocket route function.
|
||||
|
||||
That is not required, but it's recommended as it will provide you completion and checks inside the function.
|
||||
|
||||
|
||||
!!! info
|
||||
This `websocket_route` we are using comes directly from <a href="https://www.starlette.io/applications/" target="_blank">Starlette</a>.
|
||||
|
||||
That's why the naming convention is not the same as with other API path operations (`get`, `post`, etc).
|
||||
|
||||
|
||||
## Await for messages and send messages
|
||||
|
||||
In your WebSocket route you can `await` for messages and send messages.
|
||||
@@ -57,6 +48,32 @@ In your WebSocket route you can `await` for messages and send messages.
|
||||
|
||||
You can receive and send binary, text, and JSON data.
|
||||
|
||||
## Using `Depends` and others
|
||||
|
||||
In WebSocket endpoints you can import from `fastapi` and use:
|
||||
|
||||
* `Depends`
|
||||
* `Security`
|
||||
* `Cookie`
|
||||
* `Header`
|
||||
* `Path`
|
||||
* `Query`
|
||||
|
||||
They work the same way as for other FastAPI endpoints/*path operations*:
|
||||
|
||||
```Python hl_lines="55 56 57 58 59 60 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78"
|
||||
{!./src/websockets/tutorial002.py!}
|
||||
```
|
||||
|
||||
!!! info
|
||||
In a WebSocket it doesn't really make sense to raise an `HTTPException`. So it's better to close the WebSocket connection directly.
|
||||
|
||||
You can use a closing code from the <a href="https://tools.ietf.org/html/rfc6455#section-7.4.1" target="_blank">valid codes defined in the specification</a>.
|
||||
|
||||
In the future, there will be a `WebSocketException` that you will be able to `raise` from anywhere, and add exception handlers for it. It depends on the <a href="https://github.com/encode/starlette/pull/527" target="_blank">PR #527</a> in Starlette.
|
||||
|
||||
## More info
|
||||
|
||||
To learn more about the options, check Starlette's documentation for:
|
||||
|
||||
* <a href="https://www.starlette.io/applications/" target="_blank">Applications (`websocket_route`)</a>.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
|
||||
|
||||
__version__ = "0.23.0"
|
||||
__version__ = "0.24.0"
|
||||
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
|
||||
@@ -203,6 +203,18 @@ class FastAPI(Starlette):
|
||||
|
||||
return decorator
|
||||
|
||||
def add_api_websocket_route(
|
||||
self, path: str, endpoint: Callable, name: str = None
|
||||
) -> None:
|
||||
self.router.add_api_websocket_route(path, endpoint, name=name)
|
||||
|
||||
def websocket(self, path: str, name: str = None) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.add_api_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def include_router(
|
||||
self,
|
||||
router: routing.APIRouter,
|
||||
|
||||
@@ -26,6 +26,7 @@ class Dependant:
|
||||
name: str = None,
|
||||
call: Callable = None,
|
||||
request_param_name: str = None,
|
||||
websocket_param_name: str = None,
|
||||
background_tasks_param_name: str = None,
|
||||
security_scopes_param_name: str = None,
|
||||
security_scopes: List[str] = None,
|
||||
@@ -38,6 +39,7 @@ class Dependant:
|
||||
self.dependencies = dependencies or []
|
||||
self.security_requirements = security_schemes or []
|
||||
self.request_param_name = request_param_name
|
||||
self.websocket_param_name = websocket_param_name
|
||||
self.background_tasks_param_name = background_tasks_param_name
|
||||
self.security_scopes = security_scopes
|
||||
self.security_scopes_param_name = security_scopes_param_name
|
||||
|
||||
@@ -33,6 +33,7 @@ from starlette.background import BackgroundTasks
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
||||
from starlette.requests import Request
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
param_supported_types = (
|
||||
str,
|
||||
@@ -184,6 +185,8 @@ def get_dependant(
|
||||
)
|
||||
elif lenient_issubclass(param.annotation, Request):
|
||||
dependant.request_param_name = param_name
|
||||
elif lenient_issubclass(param.annotation, WebSocket):
|
||||
dependant.websocket_param_name = param_name
|
||||
elif lenient_issubclass(param.annotation, BackgroundTasks):
|
||||
dependant.background_tasks_param_name = param_name
|
||||
elif lenient_issubclass(param.annotation, SecurityScopes):
|
||||
@@ -279,7 +282,7 @@ def is_coroutine_callable(call: Callable) -> bool:
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
request: Request,
|
||||
request: Union[Request, WebSocket],
|
||||
dependant: Dependant,
|
||||
body: Dict[str, Any] = None,
|
||||
background_tasks: BackgroundTasks = None,
|
||||
@@ -326,8 +329,10 @@ async def solve_dependencies(
|
||||
)
|
||||
values.update(body_values)
|
||||
errors.extend(body_errors)
|
||||
if dependant.request_param_name:
|
||||
if dependant.request_param_name and isinstance(request, Request):
|
||||
values[dependant.request_param_name] = request
|
||||
elif dependant.websocket_param_name and isinstance(request, WebSocket):
|
||||
values[dependant.websocket_param_name] = request
|
||||
if dependant.background_tasks_param_name:
|
||||
if background_tasks is None:
|
||||
background_tasks = BackgroundTasks()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
from fastapi import params
|
||||
@@ -21,8 +22,14 @@ from starlette.concurrency import run_in_threadpool
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.routing import compile_path, get_name, request_response
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
||||
from starlette.routing import (
|
||||
compile_path,
|
||||
get_name,
|
||||
request_response,
|
||||
websocket_session,
|
||||
)
|
||||
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
|
||||
def serialize_response(*, field: Field = None, response: Response) -> Any:
|
||||
@@ -97,6 +104,35 @@ def get_app(
|
||||
return app
|
||||
|
||||
|
||||
def get_websocket_app(dependant: Dependant) -> Callable:
|
||||
async def app(websocket: WebSocket) -> None:
|
||||
values, errors, _ = await solve_dependencies(
|
||||
request=websocket, dependant=dependant
|
||||
)
|
||||
if errors:
|
||||
await websocket.close(code=WS_1008_POLICY_VIOLATION)
|
||||
errors_out = ValidationError(errors)
|
||||
raise HTTPException(
|
||||
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
|
||||
)
|
||||
assert dependant.call is not None, "dependant.call must me a function"
|
||||
await dependant.call(**values)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class APIWebSocketRoute(routing.WebSocketRoute):
|
||||
def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
|
||||
self.path = path
|
||||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
self.dependant = get_dependant(path=path, call=self.endpoint)
|
||||
self.app = websocket_session(get_websocket_app(dependant=self.dependant))
|
||||
regex = "^" + path + "$"
|
||||
regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
|
||||
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
|
||||
|
||||
|
||||
class APIRoute(routing.Route):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -281,6 +317,19 @@ class APIRouter(routing.Router):
|
||||
|
||||
return decorator
|
||||
|
||||
def add_api_websocket_route(
|
||||
self, path: str, endpoint: Callable, name: str = None
|
||||
) -> None:
|
||||
route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
|
||||
self.routes.append(route)
|
||||
|
||||
def websocket(self, path: str, name: str = None) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
self.add_api_websocket_route(path, func, name=name)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def include_router(
|
||||
self,
|
||||
router: "APIRouter",
|
||||
@@ -326,6 +375,10 @@ class APIRouter(routing.Router):
|
||||
include_in_schema=route.include_in_schema,
|
||||
name=route.name,
|
||||
)
|
||||
elif isinstance(route, APIWebSocketRoute):
|
||||
self.add_api_websocket_route(
|
||||
prefix + route.path, route.endpoint, name=route.name
|
||||
)
|
||||
elif isinstance(route, routing.WebSocketRoute):
|
||||
self.add_websocket_route(
|
||||
prefix + route.path, route.endpoint, name=route.name
|
||||
|
||||
@@ -20,7 +20,7 @@ classifiers = [
|
||||
]
|
||||
requires = [
|
||||
"starlette >=0.11.1,<=0.12.0",
|
||||
"pydantic >=0.17,<=0.25.0"
|
||||
"pydantic >=0.17,<=0.26.0"
|
||||
]
|
||||
description-file = "README.md"
|
||||
requires-python = ">=3.6"
|
||||
|
||||
0
tests/test_tutorial/test_websockets/__init__.py
Normal file
0
tests/test_tutorial/test_websockets/__init__.py
Normal file
25
tests/test_tutorial/test_websockets/test_tutorial001.py
Normal file
25
tests/test_tutorial/test_websockets/test_tutorial001.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from websockets.tutorial001 import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_main():
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert b"<!DOCTYPE html>" in response.content
|
||||
|
||||
|
||||
def test_websocket():
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect("/ws") as websocket:
|
||||
message = "Message one"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}"
|
||||
message = "Message two"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}"
|
||||
83
tests/test_tutorial/test_websockets/test_tutorial002.py
Normal file
83
tests/test_tutorial/test_websockets/test_tutorial002.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
from websockets.tutorial002 import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_main():
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert b"<!DOCTYPE html>" in response.content
|
||||
|
||||
|
||||
def test_websocket_with_cookie():
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(
|
||||
"/items/1/ws", cookies={"session": "fakesession"}
|
||||
) as websocket:
|
||||
message = "Message one"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == "Session Cookie or X-Client Header value is: fakesession"
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}, for item ID: 1"
|
||||
message = "Message two"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == "Session Cookie or X-Client Header value is: fakesession"
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}, for item ID: 1"
|
||||
|
||||
|
||||
def test_websocket_with_header():
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(
|
||||
"/items/2/ws", headers={"X-Client": "xmen"}
|
||||
) as websocket:
|
||||
message = "Message one"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}, for item ID: 2"
|
||||
message = "Message two"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}, for item ID: 2"
|
||||
|
||||
|
||||
def test_websocket_with_header_and_query():
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
with client.websocket_connect(
|
||||
"/items/2/ws?q=baz", headers={"X-Client": "xmen"}
|
||||
) as websocket:
|
||||
message = "Message one"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||
data = websocket.receive_text()
|
||||
assert data == "Query parameter q is: baz"
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}, for item ID: 2"
|
||||
message = "Message two"
|
||||
websocket.send_text(message)
|
||||
data = websocket.receive_text()
|
||||
assert data == "Session Cookie or X-Client Header value is: xmen"
|
||||
data = websocket.receive_text()
|
||||
assert data == "Query parameter q is: baz"
|
||||
data = websocket.receive_text()
|
||||
assert data == f"Message text was: {message}, for item ID: 2"
|
||||
|
||||
|
||||
def test_websocket_no_credentials():
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
client.websocket_connect("/items/2/ws")
|
||||
|
||||
|
||||
def test_websocket_invalid_data():
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
client.websocket_connect("/items/foo/ws", headers={"X-Client": "xmen"})
|
||||
@@ -28,6 +28,13 @@ async def routerprefixindex(websocket: WebSocket):
|
||||
await websocket.close()
|
||||
|
||||
|
||||
@router.websocket("/router2")
|
||||
async def routerindex(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.send_text("Hello, router!")
|
||||
await websocket.close()
|
||||
|
||||
|
||||
app.include_router(router)
|
||||
app.include_router(prefix_router, prefix="/prefix")
|
||||
|
||||
@@ -51,3 +58,10 @@ def test_prefix_router():
|
||||
with client.websocket_connect("/prefix/") as websocket:
|
||||
data = websocket.receive_text()
|
||||
assert data == "Hello, router with prefix!"
|
||||
|
||||
|
||||
def test_router2():
|
||||
client = TestClient(app)
|
||||
with client.websocket_connect("/router2") as websocket:
|
||||
data = websocket.receive_text()
|
||||
assert data == "Hello, router!"
|
||||
|
||||
Reference in New Issue
Block a user