diff --git a/backend/bracket/logic/scheduling/upcoming_matches.py b/backend/bracket/logic/scheduling/upcoming_matches.py index 62b5d5ea..d51dbb63 100644 --- a/backend/bracket/logic/scheduling/upcoming_matches.py +++ b/backend/bracket/logic/scheduling/upcoming_matches.py @@ -3,26 +3,43 @@ from fastapi import HTTPException from bracket.logic.scheduling.ladder_teams import get_possible_upcoming_matches_for_swiss from bracket.models.db.match import MatchFilter, SuggestedMatch from bracket.models.db.round import Round -from bracket.models.db.stage_item import StageType +from bracket.models.db.stage_item import StageItem, StageType +from bracket.models.db.util import RoundWithMatches, StageItemWithRounds from bracket.sql.rounds import get_rounds_for_stage_item from bracket.sql.stages import get_full_tournament_details from bracket.sql.teams import get_teams_with_members -from bracket.utils.id_types import TournamentId +from bracket.utils.id_types import StageItemId, TournamentId from bracket.utils.types import assert_some -async def get_upcoming_matches_for_swiss_round( - match_filter: MatchFilter, round_: Round, tournament_id: TournamentId -) -> list[SuggestedMatch]: - [stage] = await get_full_tournament_details( - tournament_id, stage_item_ids={round_.stage_item_id} +async def get_draft_round_in_stage_item( + tournament_id: TournamentId, + stage_item_id: StageItemId, +) -> tuple[RoundWithMatches, StageItemWithRounds]: + [stage] = await get_full_tournament_details(tournament_id, stage_item_ids={stage_item_id}) + draft_round, stage_item = next( + ( + (round_, stage_item) + for stage_item in stage.stage_items + for round_ in stage_item.rounds + if round_.is_draft + ), + (None, None), ) - assert len(stage.stage_items) == 1 - [stage_item] = stage.stage_items + if draft_round is None or stage_item is None: + raise HTTPException(400, "Expected stage item to be of type SWISS.") + return draft_round, stage_item + +async def get_upcoming_matches_for_swiss_round( + match_filter: MatchFilter, stage_item: StageItem, round_: Round, tournament_id: TournamentId +) -> list[SuggestedMatch]: if stage_item.type is not StageType.SWISS: raise HTTPException(400, "Expected stage item to be of type SWISS.") + if not round_.is_draft: + raise HTTPException(400, "There is no draft round, so no matches can be scheduled.") + rounds = await get_rounds_for_stage_item(tournament_id, assert_some(stage_item.id)) teams = await get_teams_with_members(tournament_id, only_active_teams=True) diff --git a/backend/bracket/routes/matches.py b/backend/bracket/routes/matches.py index c11b4e40..fd730930 100644 --- a/backend/bracket/routes/matches.py +++ b/backend/bracket/routes/matches.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from bracket.logic.planning.matches import ( get_scheduled_matches, @@ -10,6 +10,7 @@ from bracket.logic.ranking.elo import ( recalculate_ranking_for_stage_item_id, ) from bracket.logic.scheduling.upcoming_matches import ( + get_draft_round_in_stage_item, get_upcoming_matches_for_swiss_round, ) from bracket.models.db.match import ( @@ -21,36 +22,34 @@ from bracket.models.db.match import ( MatchRescheduleBody, SuggestedMatch, ) -from bracket.models.db.round import Round from bracket.models.db.user import UserPublic -from bracket.models.db.util import RoundWithMatches from bracket.routes.auth import user_authenticated_for_tournament from bracket.routes.models import SingleMatchResponse, SuccessResponse, UpcomingMatchesResponse -from bracket.routes.util import match_dependency, round_dependency, round_with_matches_dependency +from bracket.routes.util import match_dependency from bracket.sql.courts import get_all_courts_in_tournament from bracket.sql.matches import sql_create_match, sql_delete_match, sql_update_match from bracket.sql.rounds import get_round_by_id from bracket.sql.stages import get_full_tournament_details from bracket.sql.tournaments import sql_get_tournament from bracket.sql.validation import check_foreign_keys_belong_to_tournament -from bracket.utils.id_types import MatchId, TournamentId +from bracket.utils.id_types import MatchId, StageItemId, TournamentId from bracket.utils.types import assert_some router = APIRouter() @router.get( - "/tournaments/{tournament_id}/rounds/{round_id}/upcoming_matches", + "/tournaments/{tournament_id}/stage_items/{stage_item_id}/upcoming_matches", response_model=UpcomingMatchesResponse, ) async def get_matches_to_schedule( tournament_id: TournamentId, + stage_item_id: StageItemId, elo_diff_threshold: int = 200, iterations: int = 200, only_recommended: bool = False, limit: int = 50, _: UserPublic = Depends(user_authenticated_for_tournament), - round_: Round = Depends(round_dependency), ) -> UpcomingMatchesResponse: match_filter = MatchFilter( elo_diff_threshold=elo_diff_threshold, @@ -59,11 +58,12 @@ async def get_matches_to_schedule( iterations=iterations, ) - if not round_.is_draft: - raise HTTPException(400, "There is no draft round, so no matches can be scheduled.") + draft_round, stage_item = await get_draft_round_in_stage_item(tournament_id, stage_item_id) return UpcomingMatchesResponse( - data=await get_upcoming_matches_for_swiss_round(match_filter, round_, tournament_id) + data=await get_upcoming_matches_for_swiss_round( + match_filter, stage_item, draft_round, tournament_id + ) ) @@ -123,20 +123,17 @@ async def reschedule_match( @router.post( - "/tournaments/{tournament_id}/rounds/{round_id}/schedule_auto", + "/tournaments/{tournament_id}/stage_items/{stage_item_id}/schedule_auto", response_model=SuccessResponse, ) async def create_matches_automatically( tournament_id: TournamentId, + stage_item_id: StageItemId, elo_diff_threshold: int = 100, iterations: int = 200, only_recommended: bool = False, _: UserPublic = Depends(user_authenticated_for_tournament), - round_: RoundWithMatches = Depends(round_with_matches_dependency), ) -> SuccessResponse: - if not round_.is_draft: - raise HTTPException(400, "There is no draft round, so no matches can be scheduled.") - match_filter = MatchFilter( elo_diff_threshold=elo_diff_threshold, only_recommended=only_recommended, @@ -144,13 +141,14 @@ async def create_matches_automatically( iterations=iterations, ) + draft_round, stage_item = await get_draft_round_in_stage_item(tournament_id, stage_item_id) courts = await get_all_courts_in_tournament(tournament_id) tournament = await sql_get_tournament(tournament_id) - limit = len(courts) - len(round_.matches) + limit = len(courts) - len(draft_round.matches) for __ in range(limit): all_matches_to_schedule = await get_upcoming_matches_for_swiss_round( - match_filter, round_, tournament_id + match_filter, stage_item, draft_round, tournament_id ) if len(all_matches_to_schedule) < 1: break @@ -158,10 +156,10 @@ async def create_matches_automatically( match = all_matches_to_schedule[0] assert isinstance(match, SuggestedMatch) - assert round_.id and match.team1.id and match.team2.id + assert draft_round.id and match.team1.id and match.team2.id await sql_create_match( MatchCreateBody( - round_id=round_.id, + round_id=draft_round.id, team1_id=match.team1.id, team2_id=match.team2.id, court_id=None, diff --git a/backend/tests/integration_tests/api/auto_scheduling_matches_test.py b/backend/tests/integration_tests/api/auto_scheduling_matches_test.py index 6d46b1cd..a9734624 100644 --- a/backend/tests/integration_tests/api/auto_scheduling_matches_test.py +++ b/backend/tests/integration_tests/api/auto_scheduling_matches_test.py @@ -66,13 +66,13 @@ async def test_schedule_matches_auto( ], ), ) - round_1_id = await sql_create_round( + await sql_create_round( RoundToInsert(stage_item_id=stage_item_1.id, name="", is_draft=True, is_active=False), ) response = await send_tournament_request( HTTPMethod.POST, - f"rounds/{round_1_id}/schedule_auto", + f"stage_items/{stage_item_1.id}/schedule_auto", auth_context, ) stages = await get_full_tournament_details(tournament_id) diff --git a/backend/tests/integration_tests/api/matches_test.py b/backend/tests/integration_tests/api/matches_test.py index 3b642a34..81a81d57 100644 --- a/backend/tests/integration_tests/api/matches_test.py +++ b/backend/tests/integration_tests/api/matches_test.py @@ -271,7 +271,7 @@ async def test_upcoming_matches_endpoint( "stage_item_id": stage_item_inserted.id, } ) - ) as round_inserted, + ), inserted_team( DUMMY_TEAM1.model_copy( update={"tournament_id": auth_context.tournament.id, "elo_score": Decimal("1150.0")} @@ -308,7 +308,10 @@ async def test_upcoming_matches_endpoint( ) as player_inserted_4, ): json_response = await send_tournament_request( - HTTPMethod.GET, f"rounds/{round_inserted.id}/upcoming_matches", auth_context, {} + HTTPMethod.GET, + f"stage_items/{stage_item_inserted.id}/upcoming_matches", + auth_context, + {}, ) assert json_response == { "data": [ diff --git a/frontend/src/components/buttons/create_matches_auto.tsx b/frontend/src/components/buttons/create_matches_auto.tsx index 7b105e63..7378f5fc 100644 --- a/frontend/src/components/buttons/create_matches_auto.tsx +++ b/frontend/src/components/buttons/create_matches_auto.tsx @@ -12,19 +12,16 @@ export function AutoCreateMatchesButton({ tournamentData, swrStagesResponse, swrUpcomingMatchesResponse, - roundId, + stageItemId, schedulerSettings, }: { schedulerSettings: SchedulerSettings; - roundId: number; + stageItemId: number; tournamentData: Tournament; swrStagesResponse: SWRResponse; swrUpcomingMatchesResponse: SWRResponse; }) { const { t } = useTranslation(); - if (roundId == null) { - return null; - } return (