🐛 Fix Peewee contextvars handling in docs (#879)

This commit is contained in:
Sebastián Ramírez
2020-01-17 09:59:38 +01:00
committed by GitHub
parent e5d7878856
commit 3f53deebc9
3 changed files with 160 additions and 23 deletions

View File

@@ -3,22 +3,20 @@ from contextvars import ContextVar
import peewee
DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())
class PeeweeConnectionState(peewee._ConnectionState):
def __init__(self, **kwargs):
super().__setattr__("_state", {})
self._state["closed"] = ContextVar("closed", default=True)
self._state["conn"] = ContextVar("conn", default=None)
self._state["ctx"] = ContextVar("ctx", default=[])
self._state["transactions"] = ContextVar("transactions", default=[])
super().__setattr__("_state", db_state)
super().__init__(**kwargs)
def __setattr__(self, name, value):
self._state[name].set(value)
self._state.get()[name] = value
def __getattr__(self, name):
return self._state[name].get()
return self._state.get()[name]
db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

View File

@@ -2,8 +2,10 @@ import time
from typing import List
from fastapi import Depends, FastAPI, HTTPException
from starlette.requests import Request
from . import crud, database, models, schemas
from .database import db_state_default
database.db.connect()
database.db.create_tables([models.User, models.Item])
@@ -11,8 +13,9 @@ database.db.close()
app = FastAPI()
sleep_time = 10
# Dependency
def get_db():
try:
database.db.connect()
@@ -22,6 +25,14 @@ def get_db():
database.db.close()
@app.middleware("http")
async def reset_db_middleware(request: Request, call_next):
database.db._state._state.set(db_state_default.copy())
database.db._state.reset()
response = await call_next(request)
return response
@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
db_user = crud.get_user_by_email(email=user.email)
@@ -65,6 +76,8 @@ def read_items(skip: int = 0, limit: int = 100):
"/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
time.sleep(15) # Fake long processing request
global sleep_time
sleep_time = max(0, sleep_time - 1)
time.sleep(sleep_time) # Fake long processing request
users = crud.get_users(skip=skip, limit=limit)
return users