refactor(oidc): Get rid of OidcBackend

Now that we don't use it for tests, we don't need it anymore.

Signed-off-by: Kévin Commaille <zecakeh@tedomum.fr>
This commit is contained in:
Kévin Commaille
2025-02-26 15:34:59 +01:00
committed by Damir Jelić
parent 26cb805e0f
commit 377f34fae2
4 changed files with 242 additions and 497 deletions

View File

@@ -1,100 +0,0 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for that specific language governing permissions and
// limitations under the License.
//! Trait for defining an implementation for an OIDC backend.
//!
//! Used mostly for testing purposes.
use mas_oidc_client::{
requests::authorization_code::AuthorizationValidationData,
types::{
client_credentials::ClientCredentials,
iana::oauth::OAuthTokenTypeHint,
oidc::VerifiedProviderMetadata,
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
},
};
use url::Url;
use super::{AuthorizationCode, OauthDiscoveryError, OidcError, OidcSessionTokens};
pub(crate) mod server;
pub(super) struct RefreshedSessionTokens {
pub access_token: String,
pub refresh_token: Option<String>,
}
#[async_trait::async_trait]
pub(super) trait OidcBackend: std::fmt::Debug + Send + Sync {
async fn discover(
&self,
insecure: bool,
) -> Result<VerifiedProviderMetadata, OauthDiscoveryError>;
async fn register_client(
&self,
registration_endpoint: &Url,
client_metadata: VerifiedClientMetadata,
software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, OidcError>;
async fn trade_authorization_code_for_tokens(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
auth_code: AuthorizationCode,
validation_data: AuthorizationValidationData,
) -> Result<OidcSessionTokens, OidcError>;
async fn refresh_access_token(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
refresh_token: String,
) -> Result<RefreshedSessionTokens, OidcError>;
async fn revoke_token(
&self,
client_credentials: ClientCredentials,
revocation_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
) -> Result<(), OidcError>;
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn request_device_authorization(
&self,
device_authorization_endpoint: Url,
client_id: oauth2::ClientId,
scopes: Vec<oauth2::Scope>,
) -> Result<
oauth2::StandardDeviceAuthorizationResponse,
oauth2::basic::BasicRequestTokenError<oauth2::HttpClientError<reqwest::Error>>,
>;
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn exchange_device_code(
&self,
token_endpoint: Url,
client_id: oauth2::ClientId,
device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse,
) -> Result<
OidcSessionTokens,
oauth2::RequestTokenError<
oauth2::HttpClientError<reqwest::Error>,
oauth2::DeviceCodeErrorResponse,
>,
>;
}

View File

@@ -1,274 +0,0 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for that specific language governing permissions and
// limitations under the License.
//! Actual implementation of the OIDC backend, using the mas_oidc_client
//! implementation.
use std::{future::Future, pin::Pin};
use chrono::Utc;
use http::StatusCode;
use mas_oidc_client::{
http_service::HttpService,
requests::{
authorization_code::{access_token_with_authorization_code, AuthorizationValidationData},
discovery::{discover, insecure_discover},
refresh_token::refresh_access_token,
registration::register_client,
revocation::revoke_token,
},
types::{
client_credentials::ClientCredentials,
iana::oauth::OAuthTokenTypeHint,
oidc::{ProviderMetadata, ProviderMetadataVerificationError, VerifiedProviderMetadata},
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
},
};
use oauth2::{AsyncHttpClient, HttpClientError, HttpRequest, HttpResponse};
use ruma::api::client::discovery::{get_authentication_issuer, get_authorization_server_metadata};
use url::Url;
use super::{OidcBackend, OidcError, RefreshedSessionTokens};
use crate::{
authentication::oidc::{rng, AuthorizationCode, OauthDiscoveryError, OidcSessionTokens},
http_client::HttpClient,
Client,
};
#[derive(Debug)]
pub(crate) struct OidcServer {
client: Client,
}
impl OidcServer {
pub(crate) fn new(client: Client) -> Self {
Self { client }
}
fn http_client(&self) -> &HttpClient {
&self.client.inner.http_client
}
fn http_service(&self) -> HttpService {
HttpService::new(self.http_client().clone())
}
}
#[async_trait::async_trait]
impl OidcBackend for OidcServer {
async fn discover(
&self,
insecure: bool,
) -> Result<VerifiedProviderMetadata, OauthDiscoveryError> {
match self.client.send(get_authorization_server_metadata::msc2965::Request::new()).await {
Ok(response) => {
let metadata = response.metadata.deserialize_as::<ProviderMetadata>()?;
let result = if insecure {
metadata.insecure_verify_metadata()
} else {
// The mas-oidc-client method needs to compare the issuer for validation. It's a
// bit unnecessary because we take it from the metadata, oh well.
let issuer = metadata.issuer.clone().ok_or(
mas_oidc_client::error::DiscoveryError::Validation(
ProviderMetadataVerificationError::MissingIssuer,
),
)?;
metadata.validate(&issuer)
};
return Ok(result.map_err(mas_oidc_client::error::DiscoveryError::Validation)?);
}
Err(error)
if error
.as_client_api_error()
.is_some_and(|err| err.status_code == StatusCode::NOT_FOUND) =>
{
// Fallback to OIDC discovery.
}
Err(error) => return Err(error.into()),
};
// TODO: remove this fallback behavior when the metadata endpoint has wider
// support.
#[allow(deprecated)]
let issuer =
match self.client.send(get_authentication_issuer::msc2965::Request::new()).await {
Ok(response) => response.issuer,
Err(error)
if error
.as_client_api_error()
.is_some_and(|err| err.status_code == StatusCode::NOT_FOUND) =>
{
return Err(OauthDiscoveryError::NotSupported);
}
Err(error) => return Err(error.into()),
};
if insecure {
insecure_discover(&self.http_service(), &issuer).await.map_err(Into::into)
} else {
discover(&self.http_service(), &issuer).await.map_err(Into::into)
}
}
async fn trade_authorization_code_for_tokens(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
auth_code: AuthorizationCode,
validation_data: AuthorizationValidationData,
) -> Result<OidcSessionTokens, OidcError> {
let (response, _) = access_token_with_authorization_code(
&self.http_service(),
credentials.clone(),
provider_metadata.token_endpoint(),
auth_code.code,
validation_data,
None,
Utc::now(),
&mut rng()?,
)
.await?;
Ok(OidcSessionTokens {
access_token: response.access_token,
refresh_token: response.refresh_token,
})
}
async fn refresh_access_token(
&self,
provider_metadata: VerifiedProviderMetadata,
credentials: ClientCredentials,
refresh_token: String,
) -> Result<RefreshedSessionTokens, OidcError> {
refresh_access_token(
&self.http_service(),
credentials,
provider_metadata.token_endpoint(),
refresh_token,
None,
None,
None,
Utc::now(),
&mut rng()?,
)
.await
.map(|(response, _id_token)| RefreshedSessionTokens {
access_token: response.access_token,
refresh_token: response.refresh_token,
})
.map_err(Into::into)
}
async fn register_client(
&self,
registration_endpoint: &Url,
client_metadata: VerifiedClientMetadata,
software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, OidcError> {
register_client(
&self.http_service(),
registration_endpoint,
client_metadata,
software_statement,
)
.await
.map_err(Into::into)
}
async fn revoke_token(
&self,
client_credentials: ClientCredentials,
revocation_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
) -> Result<(), OidcError> {
Ok(revoke_token(
&self.http_service(),
client_credentials,
revocation_endpoint,
token,
token_type_hint,
Utc::now(),
&mut rng()?,
)
.await?)
}
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn request_device_authorization(
&self,
device_authorization_endpoint: Url,
client_id: oauth2::ClientId,
scopes: Vec<oauth2::Scope>,
) -> Result<
oauth2::StandardDeviceAuthorizationResponse,
oauth2::basic::BasicRequestTokenError<HttpClientError<reqwest::Error>>,
> {
let device_authorization_url =
oauth2::DeviceAuthorizationUrl::from_url(device_authorization_endpoint);
oauth2::basic::BasicClient::new(client_id)
.set_device_authorization_url(device_authorization_url)
.exchange_device_code()
.add_scopes(scopes)
.request_async(self.http_client())
.await
}
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn exchange_device_code(
&self,
token_endpoint: Url,
client_id: oauth2::ClientId,
device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse,
) -> Result<
OidcSessionTokens,
oauth2::RequestTokenError<HttpClientError<reqwest::Error>, oauth2::DeviceCodeErrorResponse>,
> {
use oauth2::TokenResponse;
let token_uri = oauth2::TokenUrl::from_url(token_endpoint);
let response = oauth2::basic::BasicClient::new(client_id)
.set_token_uri(token_uri)
.exchange_device_access_token(device_authorization_response)
.request_async(self.http_client(), tokio::time::sleep, None)
.await?;
let tokens = OidcSessionTokens {
access_token: response.access_token().secret().to_owned(),
refresh_token: response.refresh_token().map(|t| t.secret().to_owned()),
};
Ok(tokens)
}
}
impl<'c> AsyncHttpClient<'c> for HttpClient {
type Error = HttpClientError<reqwest::Error>;
type Future =
Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
let response = self.inner.call(request).await?;
Ok(response)
})
}
}

View File

@@ -146,22 +146,31 @@
//! [`AuthenticateError::InsufficientScope`]: ruma::api::client::error::AuthenticateError
//! [`examples/oidc_cli`]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/examples/oidc_cli
use std::{collections::HashMap, fmt, sync::Arc};
use std::{collections::HashMap, fmt, future::Future, pin::Pin, sync::Arc};
use as_variant::as_variant;
use chrono::Utc;
use eyeball::SharedObservable;
use futures_core::Stream;
pub use mas_oidc_client::{error, requests, types};
use mas_oidc_client::{
http_service::HttpService,
requests::{
account_management::{build_account_management_url, AccountManagementActionFull},
authorization_code::AuthorizationValidationData,
authorization_code::{access_token_with_authorization_code, AuthorizationValidationData},
discovery::{discover, insecure_discover},
refresh_token::refresh_access_token,
registration::register_client,
revocation::revoke_token,
},
types::{
client_credentials::ClientCredentials,
errors::{ClientError, ClientErrorCode::AccessDenied},
iana::oauth::OAuthTokenTypeHint,
oidc::{AccountManagementAction, VerifiedProviderMetadata},
oidc::{
AccountManagementAction, ProviderMetadata, ProviderMetadataVerificationError,
VerifiedProviderMetadata,
},
registration::{ClientRegistrationResponse, VerifiedClientMetadata},
requests::Prompt,
scope::{MatrixApiScopeToken, ScopeToken},
@@ -170,7 +179,9 @@ use mas_oidc_client::{
#[cfg(feature = "e2e-encryption")]
use matrix_sdk_base::crypto::types::qr_login::QrCodeData;
use matrix_sdk_base::{once_cell::sync::OnceCell, SessionMeta};
use oauth2::{AsyncHttpClient, HttpClientError, HttpRequest, HttpResponse};
use rand::{rngs::StdRng, Rng, SeedableRng};
use ruma::api::client::discovery::{get_authentication_issuer, get_authorization_server_metadata};
use serde::{Deserialize, Serialize};
use sha2::Digest as _;
use thiserror::Error;
@@ -179,7 +190,6 @@ use tracing::{debug, error, info, instrument, trace, warn};
use url::Url;
mod auth_code_builder;
mod backend;
mod cross_process;
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
pub mod qrcode;
@@ -192,13 +202,14 @@ pub use self::{
cross_process::CrossProcessRefreshLockError,
};
use self::{
backend::{server::OidcServer, OidcBackend},
cross_process::{CrossProcessRefreshLockGuard, CrossProcessRefreshManager},
qrcode::LoginWithQrCode,
registrations::{ClientId, OidcRegistrations},
};
use super::AuthData;
use crate::{client::SessionChange, Client, HttpError, RefreshTokenError, Result};
use crate::{
client::SessionChange, http_client::HttpClient, Client, HttpError, RefreshTokenError, Result,
};
pub(crate) struct OidcCtx {
/// Lock and state when multiple processes may refresh an OIDC session.
@@ -251,20 +262,25 @@ impl fmt::Debug for OidcAuthData {
pub struct Oidc {
/// The underlying Matrix API client.
client: Client,
/// The implementation of the OIDC backend.
backend: Arc<dyn OidcBackend>,
}
impl Oidc {
pub(crate) fn new(client: Client) -> Self {
Self { client: client.clone(), backend: Arc::new(OidcServer::new(client)) }
Self { client }
}
fn ctx(&self) -> &OidcCtx {
&self.client.inner.auth_ctx.oidc
}
fn http_client(&self) -> &HttpClient {
&self.client.inner.http_client
}
fn http_service(&self) -> HttpService {
HttpService::new(self.http_client().clone())
}
/// Enable a cross-process store lock on the state store, to coordinate
/// refreshes across different processes.
pub async fn enable_cross_process_refresh_lock(
@@ -630,7 +646,55 @@ impl Oidc {
/// Returns an error if a problem occurred when fetching or validating the
/// metadata.
pub async fn provider_metadata(&self) -> Result<VerifiedProviderMetadata, OauthDiscoveryError> {
self.backend.discover(self.ctx().insecure_discover).await
match self.client.send(get_authorization_server_metadata::msc2965::Request::new()).await {
Ok(response) => {
let metadata = response.metadata.deserialize_as::<ProviderMetadata>()?;
let result = if self.ctx().insecure_discover {
metadata.insecure_verify_metadata()
} else {
// The mas-oidc-client method needs to compare the issuer for validation. It's a
// bit unnecessary because we take it from the metadata, oh well.
let issuer =
metadata.issuer.clone().ok_or(error::DiscoveryError::Validation(
ProviderMetadataVerificationError::MissingIssuer,
))?;
metadata.validate(&issuer)
};
return Ok(result.map_err(error::DiscoveryError::Validation)?);
}
Err(error)
if error
.as_client_api_error()
.is_some_and(|err| err.status_code == http::StatusCode::NOT_FOUND) =>
{
// Fallback to OIDC discovery.
}
Err(error) => return Err(error.into()),
};
// TODO: remove this fallback behavior when the metadata endpoint has wider
// support.
#[allow(deprecated)]
let issuer =
match self.client.send(get_authentication_issuer::msc2965::Request::new()).await {
Ok(response) => response.issuer,
Err(error)
if error
.as_client_api_error()
.is_some_and(|err| err.status_code == http::StatusCode::NOT_FOUND) =>
{
return Err(OauthDiscoveryError::NotSupported);
}
Err(error) => return Err(error.into()),
};
if self.ctx().insecure_discover {
insecure_discover(&self.http_service(), &issuer).await.map_err(Into::into)
} else {
discover(&self.http_service(), &issuer).await.map_err(Into::into)
}
}
/// The OpenID Connect unique identifier of this client obtained after
@@ -832,10 +896,13 @@ impl Oidc {
.as_ref()
.ok_or(OidcError::NoRegistrationSupport)?;
let registration_response = self
.backend
.register_client(registration_endpoint, client_metadata, software_statement)
.await?;
let registration_response = register_client(
&self.http_service(),
registration_endpoint,
client_metadata,
software_statement,
)
.await?;
// The format of the credentials changes according to the client metadata that
// was sent. Public clients only get a client ID.
@@ -996,6 +1063,24 @@ impl Oidc {
Ok(())
}
/// The scopes to request for logging in.
fn login_scopes(device_id: Option<String>) -> Result<[ScopeToken; 3], OidcError> {
// Generate the device ID if it is not provided.
let device_id = device_id.unwrap_or_else(|| {
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.map(char::from)
.take(10)
.collect::<String>()
});
Ok([
ScopeToken::Openid,
ScopeToken::MatrixApi(MatrixApiScopeToken::Full),
ScopeToken::try_with_matrix_device(device_id).or(Err(OidcError::InvalidDeviceId))?,
])
}
/// Login via OpenID Connect with the Authorization Code flow.
///
/// This should be called after [`Oidc::register_client()`] or
@@ -1072,24 +1157,7 @@ impl Oidc {
redirect_uri: Url,
device_id: Option<String>,
) -> Result<OidcAuthCodeUrlBuilder, OidcError> {
// Generate the device ID if it is not provided.
let device_id = if let Some(device_id) = device_id {
device_id
} else {
rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.map(char::from)
.take(10)
.collect::<String>()
};
let scope = [
ScopeToken::Openid,
ScopeToken::MatrixApi(MatrixApiScopeToken::Full),
ScopeToken::try_with_matrix_device(device_id).or(Err(OidcError::InvalidDeviceId))?,
]
.into_iter()
.collect();
let scope = Self::login_scopes(device_id)?.into_iter().collect();
Ok(OidcAuthCodeUrlBuilder::new(self.clone(), scope, redirect_uri))
}
@@ -1187,17 +1255,22 @@ impl Oidc {
let provider_metadata = self.provider_metadata().await?;
let session_tokens = self
.backend
.trade_authorization_code_for_tokens(
provider_metadata,
data.credentials(),
auth_code,
validation_data,
)
.await?;
let (response, _) = access_token_with_authorization_code(
&self.http_service(),
data.credentials(),
provider_metadata.token_endpoint(),
auth_code.code,
validation_data,
None,
Utc::now(),
&mut rng()?,
)
.await?;
self.set_session_tokens(session_tokens);
self.set_session_tokens(OidcSessionTokens {
access_token: response.access_token,
refresh_token: response.refresh_token,
});
Ok(())
}
@@ -1223,6 +1296,66 @@ impl Oidc {
}
}
/// Request codes from the authorization server for logging in with another
/// device.
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn request_device_authorization(
&self,
device_id: Option<String>,
) -> Result<oauth2::StandardDeviceAuthorizationResponse, qrcode::DeviceAuthorizationOauthError>
{
let scopes = Self::login_scopes(device_id)?
.into_iter()
.map(|scope| oauth2::Scope::new(scope.to_string()));
let client_id =
oauth2::ClientId::new(self.client_id().ok_or(OidcError::NotRegistered)?.0.clone());
let server_metadata = self.provider_metadata().await.map_err(OidcError::from)?;
let device_authorization_url = server_metadata
.device_authorization_endpoint
.clone()
.map(oauth2::DeviceAuthorizationUrl::from_url)
.ok_or(qrcode::DeviceAuthorizationOauthError::NoDeviceAuthorizationEndpoint)?;
let response = oauth2::basic::BasicClient::new(client_id)
.set_device_authorization_url(device_authorization_url)
.exchange_device_code()
.add_scopes(scopes)
.request_async(self.http_client())
.await?;
Ok(response)
}
/// Exchange the device code against an access token.
#[cfg(all(feature = "e2e-encryption", not(target_arch = "wasm32")))]
async fn exchange_device_code(
&self,
device_authorization_response: &oauth2::StandardDeviceAuthorizationResponse,
) -> Result<(), qrcode::DeviceAuthorizationOauthError> {
use oauth2::TokenResponse;
let client_id =
oauth2::ClientId::new(self.client_id().ok_or(OidcError::NotRegistered)?.0.clone());
let server_metadata = self.provider_metadata().await.map_err(OidcError::from)?;
let token_uri = oauth2::TokenUrl::from_url(server_metadata.token_endpoint().clone());
let response = oauth2::basic::BasicClient::new(client_id)
.set_token_uri(token_uri)
.exchange_device_access_token(device_authorization_response)
.request_async(self.http_client(), tokio::time::sleep, None)
.await?;
self.set_session_tokens(OidcSessionTokens {
access_token: response.access_token().secret().to_owned(),
refresh_token: response.refresh_token().map(|t| t.secret().to_owned()),
});
Ok(())
}
async fn refresh_access_token_inner(
self,
refresh_token: String,
@@ -1235,24 +1368,32 @@ impl Oidc {
hash_str(&refresh_token)
);
let new_tokens = self
.backend
.refresh_access_token(provider_metadata, credentials, refresh_token.clone())
.await?;
let (response, _) = refresh_access_token(
&self.http_service(),
credentials,
provider_metadata.token_endpoint(),
refresh_token.clone(),
None,
None,
None,
Utc::now(),
&mut rng()?,
)
.await?;
trace!(
"Token refresh: new refresh_token: {} / access_token: {:x}",
new_tokens
response
.refresh_token
.as_deref()
.map(|token| format!("{:x}", hash_str(token)))
.unwrap_or_else(|| "<none>".to_owned()),
hash_str(&new_tokens.access_token)
hash_str(&response.access_token)
);
let tokens = OidcSessionTokens {
access_token: new_tokens.access_token,
refresh_token: new_tokens.refresh_token.clone().or(Some(refresh_token)),
access_token: response.access_token,
refresh_token: response.refresh_token.or(Some(refresh_token)),
};
self.set_session_tokens(tokens.clone());
@@ -1417,14 +1558,16 @@ impl Oidc {
let tokens = self.session_tokens().ok_or(OidcError::NotAuthenticated)?;
// Revoke the access token, it should revoke both tokens.
self.backend
.revoke_token(
client_credentials,
revocation_endpoint,
tokens.access_token,
Some(OAuthTokenTypeHint::AccessToken),
)
.await?;
revoke_token(
&self.http_service(),
client_credentials,
revocation_endpoint,
tokens.access_token,
Some(OAuthTokenTypeHint::AccessToken),
Utc::now(),
&mut rng()?,
)
.await?;
if let Some(manager) = self.ctx().cross_process_token_refresh_manager.get() {
manager.on_logout().await?;
@@ -1670,3 +1813,18 @@ fn rng() -> Result<StdRng, OidcError> {
fn hash_str(x: &str) -> impl fmt::LowerHex {
sha2::Sha256::new().chain_update(x).finalize()
}
impl<'c> AsyncHttpClient<'c> for HttpClient {
type Error = HttpClientError<reqwest::Error>;
type Future =
Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
let response = self.inner.call(request).await?;
Ok(response)
})
}
}

View File

@@ -16,16 +16,13 @@ use std::future::IntoFuture;
use eyeball::SharedObservable;
use futures_core::Stream;
use mas_oidc_client::types::{
registration::VerifiedClientMetadata,
scope::{MatrixApiScopeToken, ScopeToken},
};
use mas_oidc_client::types::registration::VerifiedClientMetadata;
use matrix_sdk_base::{
boxed_into_future,
crypto::types::qr_login::{QrCodeData, QrCodeMode},
SessionMeta,
};
use oauth2::{DeviceCodeErrorResponseType, Scope, StandardDeviceAuthorizationResponse};
use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
use ruma::OwnedDeviceId;
use tracing::trace;
use vodozemac::{ecies::CheckCode, Curve25519PublicKey};
@@ -37,10 +34,7 @@ use super::{
};
#[cfg(doc)]
use crate::authentication::oidc::Oidc;
use crate::{
authentication::oidc::{OidcError, OidcSessionTokens},
Client,
};
use crate::Client;
async fn send_unexpected_message_error(
channel: &mut EstablishedSecureChannel,
@@ -167,32 +161,28 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> {
// Let's now wait for the access token to be provided to use by the OAuth 2.0
// authorization server.
trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
let session_tokens = match self.wait_for_tokens(&auth_grant_response).await {
Ok(t) => t,
Err(e) => {
// If we received an error, and it's one of the ones we should report to the
// other side, do so now.
if let Some(e) = e.as_request_token_error() {
match e {
DeviceCodeErrorResponseType::AccessDenied => {
channel.send_json(QrAuthMessage::LoginDeclined).await?;
}
DeviceCodeErrorResponseType::ExpiredToken => {
channel
.send_json(QrAuthMessage::LoginFailure {
reason: LoginFailureReason::AuthorizationExpired,
homeserver: None,
})
.await?;
}
_ => (),
if let Err(e) = self.wait_for_tokens(&auth_grant_response).await {
// If we received an error, and it's one of the ones we should report to the
// other side, do so now.
if let Some(e) = e.as_request_token_error() {
match e {
DeviceCodeErrorResponseType::AccessDenied => {
channel.send_json(QrAuthMessage::LoginDeclined).await?;
}
DeviceCodeErrorResponseType::ExpiredToken => {
channel
.send_json(QrAuthMessage::LoginFailure {
reason: LoginFailureReason::AuthorizationExpired,
homeserver: None,
})
.await?;
}
_ => (),
}
return Err(e.into());
}
return Err(e.into());
};
self.client.oidc().set_session_tokens(session_tokens);
// We only received an access token from the OAuth 2.0 authorization server, we
// have no clue who we are, so we need to figure out our user ID
@@ -303,47 +293,18 @@ impl<'a> LoginWithQrCode<'a> {
&self,
device_id: Curve25519PublicKey,
) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOauthError> {
let scopes = [
ScopeToken::MatrixApi(MatrixApiScopeToken::Full),
ScopeToken::try_with_matrix_device(device_id.to_base64()).expect(
"We should be able to create a scope token from a \
Curve25519 public key encoded as base64",
),
]
.into_iter()
.map(|scope| Scope::new(scope.to_string()))
.collect();
let oidc = self.client.oidc();
let client_id =
oauth2::ClientId::new(oidc.client_id().ok_or(OidcError::NotRegistered)?.0.clone());
let server_metadata = oidc.provider_metadata().await.map_err(OidcError::from)?;
let device_authorization_endpoint =
server_metadata
.device_authorization_endpoint
.clone()
.ok_or(DeviceAuthorizationOauthError::NoDeviceAuthorizationEndpoint)?;
let response = oidc
.backend
.request_device_authorization(device_authorization_endpoint, client_id, scopes)
.await?;
let response = oidc.request_device_authorization(Some(device_id.to_base64())).await?;
Ok(response)
}
async fn wait_for_tokens(
&self,
auth_response: &StandardDeviceAuthorizationResponse,
) -> Result<OidcSessionTokens, DeviceAuthorizationOauthError> {
) -> Result<(), DeviceAuthorizationOauthError> {
let oidc = self.client.oidc();
let client_id =
oauth2::ClientId::new(oidc.client_id().ok_or(OidcError::NotRegistered)?.0.clone());
let server_metadata = oidc.provider_metadata().await.map_err(OidcError::from)?;
let token_endpoint = server_metadata.token_endpoint().clone();
let tokens =
oidc.backend.exchange_device_code(token_endpoint, client_id, auth_response).await?;
Ok(tokens)
oidc.exchange_device_code(auth_response).await?;
Ok(())
}
}