diff --git a/backend/bracket/logic/sso.py b/backend/bracket/logic/sso.py index b5004d2a..618ff416 100644 --- a/backend/bracket/logic/sso.py +++ b/backend/bracket/logic/sso.py @@ -22,6 +22,30 @@ async def get_discovery_document(discovery_url: str) -> DiscoveryDocument: } +async def build_openid_sso(sso_config: SSOConfig) -> SSOBase: + assert sso_config.openid_discovery_url is not None, ( + "`openid_discovery_url` should be set for OpenID SSO" + ) + assert sso_config.openid_scopes is not None, "`openid_scopes` should be set for OpenID SSO" + + def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID: + return OpenID(display_name=response["sub"]) + + GenericSSO = create_provider( + name="oidc", + discovery_document=await get_discovery_document(sso_config.openid_discovery_url), + response_convertor=convert_openid, + ) + + return GenericSSO( + client_id=sso_config.client_id, + client_secret=sso_config.client_secret, + redirect_uri=sso_config.redirect_uri, + allow_insecure_http=sso_config.allow_insecure_http, + scope=sso_config.openid_scopes.split(","), + ) + + async def build_sso(sso_config: SSOConfig) -> SSOBase: match sso_config.provider: case SSOProvider.google: @@ -39,29 +63,7 @@ async def build_sso(sso_config: SSOConfig) -> SSOBase: allow_insecure_http=sso_config.allow_insecure_http, ) case SSOProvider.openid: - assert sso_config.openid_discovery_url is not None, ( - "`openid_discovery_url` should be set for OpenID SSO" - ) - assert sso_config.openid_scopes is not None, ( - "`openid_scopes` should be set for OpenID SSO" - ) - - def convert_openid(response: dict[str, Any], _client: AsyncClient | None) -> OpenID: - return OpenID(display_name=response["sub"]) - - GenericSSO = create_provider( - name="oidc", - discovery_document=await get_discovery_document(sso_config.openid_discovery_url), - response_convertor=convert_openid, - ) - - return GenericSSO( - client_id=sso_config.client_id, - client_secret=sso_config.client_secret, - redirect_uri=sso_config.redirect_uri, - allow_insecure_http=sso_config.allow_insecure_http, - scope=sso_config.openid_scopes.split(","), - ) + return await build_openid_sso(sso_config) @cache diff --git a/backend/bracket/routes/auth.py b/backend/bracket/routes/auth.py index 2d2c748f..234424cc 100644 --- a/backend/bracket/routes/auth.py +++ b/backend/bracket/routes/auth.py @@ -1,4 +1,3 @@ -import os from typing import Any import jwt