mirror of
https://github.com/fastapi/fastapi.git
synced 2026-01-27 23:39:17 -05:00
🐛 Fix Peewee contextvars handling in docs (#879)
This commit is contained in:
committed by
GitHub
parent
e5d7878856
commit
3f53deebc9
@@ -3,22 +3,20 @@ from contextvars import ContextVar
|
||||
import peewee
|
||||
|
||||
DATABASE_NAME = "test.db"
|
||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||
|
||||
|
||||
class PeeweeConnectionState(peewee._ConnectionState):
|
||||
def __init__(self, **kwargs):
|
||||
super().__setattr__("_state", {})
|
||||
self._state["closed"] = ContextVar("closed", default=True)
|
||||
self._state["conn"] = ContextVar("conn", default=None)
|
||||
self._state["ctx"] = ContextVar("ctx", default=[])
|
||||
self._state["transactions"] = ContextVar("transactions", default=[])
|
||||
super().__setattr__("_state", db_state)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self._state[name].set(value)
|
||||
self._state.get()[name] = value
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self._state[name].get()
|
||||
return self._state.get()[name]
|
||||
|
||||
|
||||
db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)
|
||||
|
||||
@@ -2,8 +2,10 @@ import time
|
||||
from typing import List
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from . import crud, database, models, schemas
|
||||
from .database import db_state_default
|
||||
|
||||
database.db.connect()
|
||||
database.db.create_tables([models.User, models.Item])
|
||||
@@ -11,8 +13,9 @@ database.db.close()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
sleep_time = 10
|
||||
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
try:
|
||||
database.db.connect()
|
||||
@@ -22,6 +25,14 @@ def get_db():
|
||||
database.db.close()
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def reset_db_middleware(request: Request, call_next):
|
||||
database.db._state._state.set(db_state_default.copy())
|
||||
database.db._state.reset()
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
|
||||
def create_user(user: schemas.UserCreate):
|
||||
db_user = crud.get_user_by_email(email=user.email)
|
||||
@@ -65,6 +76,8 @@ def read_items(skip: int = 0, limit: int = 100):
|
||||
"/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
|
||||
)
|
||||
def read_slow_users(skip: int = 0, limit: int = 100):
|
||||
time.sleep(15) # Fake long processing request
|
||||
global sleep_time
|
||||
sleep_time = max(0, sleep_time - 1)
|
||||
time.sleep(sleep_time) # Fake long processing request
|
||||
users = crud.get_users(skip=skip, limit=limit)
|
||||
return users
|
||||
|
||||
Reference in New Issue
Block a user