From 9a76f0176aaa1e860058a3be8b6434bfd05feec8 Mon Sep 17 00:00:00 2001 From: Erik Vroon Date: Fri, 19 May 2023 19:51:16 +0200 Subject: [PATCH] Fix dummy insertions in `create-dev-db` cmd (#221) --- backend/bracket/database.py | 7 +- backend/bracket/utils/db.py | 20 ++++ backend/bracket/utils/dummy_records.py | 21 ---- backend/cli.py | 128 ++++++++++++++++++++----- backend/tests/integration_tests/sql.py | 16 +--- 5 files changed, 130 insertions(+), 62 deletions(-) diff --git a/backend/bracket/database.py b/backend/bracket/database.py index bc682277..9eb7aad5 100644 --- a/backend/bracket/database.py +++ b/backend/bracket/database.py @@ -13,7 +13,7 @@ database = Database(config.pg_dsn) engine = sqlalchemy.create_engine(config.pg_dsn) -async def init_db_when_empty() -> None: +async def init_db_when_empty() -> int | None: table_count = await database.fetch_val( 'SELECT count(*) FROM information_schema.tables WHERE table_schema = \'public\'' ) @@ -33,4 +33,7 @@ async def init_db_when_empty() -> None: password_hash=pwd_context.hash(config.admin_password), created=datetime_utc.now(), ) - await database.execute(query=users.insert(), values=admin.dict()) + user_id: int = await database.execute(query=users.insert(), values=admin.dict()) + return user_id + + return None diff --git a/backend/bracket/utils/db.py b/backend/bracket/utils/db.py index 87a8fe43..be5c1bf1 100644 --- a/backend/bracket/utils/db.py +++ b/backend/bracket/utils/db.py @@ -1,8 +1,11 @@ from typing import Type from databases import Database +from sqlalchemy import Table from sqlalchemy.sql import Select +from bracket.utils.conversion import to_string_mapping +from bracket.utils.logging import logger from bracket.utils.types import BaseModelT, assert_some @@ -24,3 +27,20 @@ async def fetch_all_parsed( ) -> list[BaseModelT]: records = await database.fetch_all(query) return [model.parse_obj(record._mapping) for record in records] + + +async def insert_generic( + database: Database, data_model: BaseModelT, table: Table, return_type: Type[BaseModelT] +) -> tuple[int, BaseModelT]: + try: + last_record_id: int = await database.execute( + query=table.insert(), values=to_string_mapping(data_model) # type: ignore[arg-type] + ) + row_inserted = await fetch_one_parsed( + database, return_type, table.select().where(table.c.id == last_record_id) + ) + assert isinstance(row_inserted, return_type) + return last_record_id, row_inserted + except Exception: + logger.exception(f'Could not insert {type(data_model).__name__}') + raise diff --git a/backend/bracket/utils/dummy_records.py b/backend/bracket/utils/dummy_records.py index 90b33d3c..6cd54785 100644 --- a/backend/bracket/utils/dummy_records.py +++ b/backend/bracket/utils/dummy_records.py @@ -211,24 +211,3 @@ DUMMY_USER_X_CLUB = UserXClub( user_id=DB_PLACEHOLDER_ID, club_id=DB_PLACEHOLDER_ID, ) - - -DUMMY_CLUBS = [DUMMY_CLUB] -DUMMY_TOURNAMENTS = [DUMMY_TOURNAMENT] -DUMMY_STAGES = [DUMMY_STAGE1, DUMMY_STAGE2] -DUMMY_ROUNDS = [DUMMY_ROUND1, DUMMY_ROUND2, DUMMY_ROUND3] -DUMMY_MATCHES = [DUMMY_MATCH1, DUMMY_MATCH2, DUMMY_MATCH3, DUMMY_MATCH4] -DUMMY_USERS = [DUMMY_USER] -DUMMY_TEAMS = [DUMMY_TEAM1, DUMMY_TEAM2, DUMMY_TEAM3, DUMMY_TEAM4] -DUMMY_PLAYERS = [ - DUMMY_PLAYER1, - DUMMY_PLAYER2, - DUMMY_PLAYER3, - DUMMY_PLAYER4, - DUMMY_PLAYER5, - DUMMY_PLAYER6, - DUMMY_PLAYER7, - DUMMY_PLAYER8, - DUMMY_PLAYER9, -] -DUMMY_USERS_X_CLUBS = [DUMMY_USER_X_CLUB] diff --git a/backend/cli.py b/backend/cli.py index c91046ef..26a9c46e 100755 --- a/backend/cli.py +++ b/backend/cli.py @@ -10,6 +10,15 @@ from bracket.config import Environment, environment from bracket.database import database, engine, init_db_when_empty from bracket.logger import get_logger from bracket.logic.elo import recalculate_elo_for_tournament_id +from bracket.models.db.club import Club +from bracket.models.db.match import Match +from bracket.models.db.player import Player +from bracket.models.db.round import Round +from bracket.models.db.stage import Stage +from bracket.models.db.team import Team +from bracket.models.db.tournament import Tournament +from bracket.models.db.user import User +from bracket.models.db.user_x_club import UserXClub from bracket.schema import ( clubs, matches, @@ -22,17 +31,33 @@ from bracket.schema import ( users, users_x_clubs, ) -from bracket.utils.conversion import to_string_mapping +from bracket.utils.db import insert_generic from bracket.utils.dummy_records import ( - DUMMY_CLUBS, - DUMMY_MATCHES, - DUMMY_PLAYERS, - DUMMY_ROUNDS, - DUMMY_STAGES, - DUMMY_TEAMS, - DUMMY_TOURNAMENTS, - DUMMY_USERS, - DUMMY_USERS_X_CLUBS, + DUMMY_CLUB, + DUMMY_MATCH1, + DUMMY_MATCH2, + DUMMY_MATCH3, + DUMMY_MATCH4, + DUMMY_PLAYER1, + DUMMY_PLAYER2, + DUMMY_PLAYER3, + DUMMY_PLAYER4, + DUMMY_PLAYER5, + DUMMY_PLAYER6, + DUMMY_PLAYER7, + DUMMY_PLAYER8, + DUMMY_PLAYER9, + DUMMY_ROUND1, + DUMMY_ROUND2, + DUMMY_ROUND3, + DUMMY_STAGE1, + DUMMY_STAGE2, + DUMMY_TEAM1, + DUMMY_TEAM2, + DUMMY_TEAM3, + DUMMY_TEAM4, + DUMMY_TOURNAMENT, + DUMMY_USER, ) from bracket.utils.types import BaseModelT @@ -67,11 +92,6 @@ def cli() -> None: pass -async def bulk_insert(table: Table, rows: list[BaseModelT]) -> None: - for row in rows: - await database.execute(query=table.insert(), values=to_string_mapping(row)) # type: ignore[arg-type] - - @click.command() @run_async async def create_dev_db() -> None: @@ -80,17 +100,75 @@ async def create_dev_db() -> None: logger.warning('Initializing database with dummy records') await database.connect() metadata.drop_all(engine) - await init_db_when_empty() + real_user_id = await init_db_when_empty() - await bulk_insert(users, DUMMY_USERS) - await bulk_insert(clubs, DUMMY_CLUBS) - await bulk_insert(tournaments, DUMMY_TOURNAMENTS) - await bulk_insert(stages, DUMMY_STAGES) - await bulk_insert(teams, DUMMY_TEAMS) - await bulk_insert(players, DUMMY_PLAYERS) - await bulk_insert(rounds, DUMMY_ROUNDS) - await bulk_insert(matches, DUMMY_MATCHES) - await bulk_insert(users_x_clubs, DUMMY_USERS_X_CLUBS) + table_lookup: dict[type, Table] = { + User: users, + Club: clubs, + Stage: stages, + Team: teams, + UserXClub: users_x_clubs, + Player: players, + Round: rounds, + Match: matches, + Tournament: tournaments, + } + + async def insert_dummy(obj_to_insert: BaseModelT) -> int: + record_id, _ = await insert_generic( + database, obj_to_insert, table_lookup[type(obj_to_insert)], type(obj_to_insert) + ) + return record_id + + user_id_1 = await insert_dummy(DUMMY_USER) + club_id_1 = await insert_dummy(DUMMY_CLUB) + await insert_dummy(UserXClub(user_id=user_id_1, club_id=club_id_1)) + + if real_user_id is not None: + await insert_dummy(UserXClub(user_id=real_user_id, club_id=club_id_1)) + + tournament_id_1 = await insert_dummy(DUMMY_TOURNAMENT.copy(update={'club_id': club_id_1})) + stage_id_1 = await insert_dummy(DUMMY_STAGE1.copy(update={'tournament_id': tournament_id_1})) + stage_id_2 = await insert_dummy(DUMMY_STAGE2.copy(update={'tournament_id': tournament_id_1})) + team_id_1 = await insert_dummy(DUMMY_TEAM1.copy(update={'tournament_id': tournament_id_1})) + team_id_2 = await insert_dummy(DUMMY_TEAM2.copy(update={'tournament_id': tournament_id_1})) + team_id_3 = await insert_dummy(DUMMY_TEAM3.copy(update={'tournament_id': tournament_id_1})) + team_id_4 = await insert_dummy(DUMMY_TEAM4.copy(update={'tournament_id': tournament_id_1})) + + await insert_dummy(DUMMY_PLAYER1.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER2.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER3.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER4.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER5.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER6.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER7.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER8.copy(update={'tournament_id': tournament_id_1})) + await insert_dummy(DUMMY_PLAYER9.copy(update={'tournament_id': tournament_id_1})) + + round_id_1 = await insert_dummy(DUMMY_ROUND1.copy(update={'stage_id': stage_id_1})) + round_id_2 = await insert_dummy(DUMMY_ROUND2.copy(update={'stage_id': stage_id_1})) + round_id_3 = await insert_dummy(DUMMY_ROUND3.copy(update={'stage_id': stage_id_2})) + + await insert_dummy( + DUMMY_MATCH1.copy( + update={'round_id': round_id_1, 'team1_id': team_id_1, 'team2_id': team_id_2} + ), + ) + await insert_dummy( + DUMMY_MATCH2.copy( + update={'round_id': round_id_1, 'team1_id': team_id_3, 'team2_id': team_id_4} + ), + ) + await insert_dummy( + DUMMY_MATCH3.copy( + update={'round_id': round_id_2, 'team1_id': team_id_2, 'team2_id': team_id_4} + ), + ) + await insert_dummy( + DUMMY_MATCH4.copy( + update={'round_id': round_id_3, 'team1_id': team_id_3, 'team2_id': team_id_1} + ), + ) for tournament in await database.fetch_all(tournaments.select()): await recalculate_elo_for_tournament_id(tournament.id) # type: ignore[attr-defined] diff --git a/backend/tests/integration_tests/sql.py b/backend/tests/integration_tests/sql.py index ab620912..d9e0d13d 100644 --- a/backend/tests/integration_tests/sql.py +++ b/backend/tests/integration_tests/sql.py @@ -26,10 +26,8 @@ from bracket.schema import ( users, users_x_clubs, ) -from bracket.utils.conversion import to_string_mapping -from bracket.utils.db import fetch_one_parsed +from bracket.utils.db import insert_generic from bracket.utils.dummy_records import DUMMY_CLUB, DUMMY_TOURNAMENT -from bracket.utils.logging import logger from bracket.utils.types import BaseModelT, assert_some from tests.integration_tests.mocks import MOCK_USER, get_mock_token from tests.integration_tests.models import AuthContext @@ -44,18 +42,8 @@ async def assert_row_count_and_clear(table: Table, expected_rows: int) -> None: async def inserted_generic( data_model: BaseModelT, table: Table, return_type: Type[BaseModelT] ) -> AsyncIterator[BaseModelT]: - try: - last_record_id = await database.execute( - query=table.insert(), values=to_string_mapping(data_model) # type: ignore[arg-type] - ) - except: - logger.exception(f'Could not insert {type(data_model).__name__}') - raise + last_record_id, row_inserted = await insert_generic(database, data_model, table, return_type) - row_inserted = await fetch_one_parsed( - database, return_type, table.select().where(table.c.id == last_record_id) - ) - assert isinstance(row_inserted, return_type) try: yield row_inserted finally: