diff --git a/crates/matrix-sdk/src/authentication/oauth/qrcode/grant.rs b/crates/matrix-sdk/src/authentication/oauth/qrcode/grant.rs index a79a7bb25..8a9754227 100644 --- a/crates/matrix-sdk/src/authentication/oauth/qrcode/grant.rs +++ b/crates/matrix-sdk/src/authentication/oauth/qrcode/grant.rs @@ -789,11 +789,11 @@ mod test { ); } - #[async_test] - async fn test_grant_login_with_generated_qr_code() { + async fn test_grant_login_with_generated_qr_code(msc_4388: bool) { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, msc_4388) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); let device_authorization_grant = AuthorizationGrant { @@ -838,10 +838,15 @@ mod test { // Prepare the login granting future. let oauth = alice.oauth(); - let grant = oauth + let mut grant = oauth .grant_login_with_qr_code() .device_creation_timeout(Duration::from_secs(2)) .generate(); + + if msc_4388 { + grant.with_msc4388_support(); + } + let secrets_bundle = export_secrets_bundle(&alice) .await .expect("Alice should be able to export the secrets bundle"); @@ -934,10 +939,20 @@ mod test { } #[async_test] - async fn test_grant_login_with_scanned_qr_code() { + async fn test_grant_login_with_generated_qr_code_msc_4108() { + test_grant_login_with_generated_qr_code(false).await; + } + + #[async_test] + async fn test_grant_login_with_generated_qr_code_msc_4388() { + test_grant_login_with_generated_qr_code(true).await; + } + + async fn test_grant_login_with_scanned_qr_code(msc_4388: bool) { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, msc_4388) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); let device_authorization_grant = AuthorizationGrant { @@ -968,7 +983,7 @@ mod test { // Create a secure channel on the new client (Bob) and extract the QR code. let client = HttpClient::new(reqwest::Client::new(), Default::default()); - let channel = SecureChannel::login(client, &rendezvous_server.homeserver_url, false) + let channel = SecureChannel::login(client, &rendezvous_server.homeserver_url, msc_4388) .await .expect("Bob should be able to create a secure channel."); let qr_code_data = channel.qr_code_data().clone(); @@ -1060,11 +1075,22 @@ mod test { bob_task.await.expect("Bob's task should finish"); } + #[async_test] + async fn test_grant_login_with_scanned_qr_code_msc_4108() { + test_grant_login_with_scanned_qr_code(false).await; + } + + #[async_test] + async fn test_grant_login_with_scanned_qr_code_msc_4388() { + test_grant_login_with_scanned_qr_code(true).await; + } + #[async_test] async fn test_grant_login_with_scanned_qr_code_with_homeserver_swap() { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); let device_authorization_grant = AuthorizationGrant { @@ -1194,7 +1220,8 @@ mod test { { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -1320,7 +1347,8 @@ mod test { async fn test_grant_login_with_scanned_qr_code_unexpected_message_instead_of_login_protocol() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -1431,7 +1459,8 @@ mod test { async fn test_grant_login_with_generated_qr_code_device_already_exists() { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); let device_authorization_grant = AuthorizationGrant { @@ -1561,7 +1590,8 @@ mod test { async fn test_grant_login_with_scanned_qr_code_device_already_exists() { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); let device_authorization_grant = AuthorizationGrant { @@ -1676,7 +1706,8 @@ mod test { async fn test_grant_login_with_generated_qr_code_device_not_found() { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); let device_authorization_grant = AuthorizationGrant { @@ -1816,7 +1847,8 @@ mod test { async fn test_grant_login_with_scanned_qr_code_device_not_found() { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); let device_authorization_grant = AuthorizationGrant { @@ -1938,9 +1970,13 @@ mod test { #[async_test] async fn test_grant_login_with_generated_qr_code_session_expired() { let server = MatrixMockServer::new().await; - let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::from_secs(2)) - .await; + let rendezvous_server = MockedRendezvousServer::new( + server.server(), + "abcdEFG12345", + Duration::from_secs(2), + false, + ) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await; @@ -2013,9 +2049,13 @@ mod test { #[async_test] async fn test_grant_login_with_scanned_qr_code_session_expired() { let server = MatrixMockServer::new().await; - let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::from_secs(2)) - .await; + let rendezvous_server = MockedRendezvousServer::new( + server.server(), + "abcdEFG12345", + Duration::from_secs(2), + false, + ) + .await; debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await; @@ -2094,7 +2134,8 @@ mod test { async fn test_grant_login_with_generated_qr_code_login_failure_instead_of_login_protocol() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -2217,7 +2258,8 @@ mod test { async fn test_grant_login_with_scanned_qr_code_login_failure_instead_of_login_protocol() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -2325,7 +2367,8 @@ mod test { async fn test_grant_login_with_scanned_qr_code_login_failure_instead_of_login_success() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -2470,7 +2513,8 @@ mod test { async fn test_grant_login_with_generated_qr_code_login_failure_instead_of_login_success() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -2598,7 +2642,8 @@ mod test { async fn test_grant_login_with_generated_qr_code_unexpected_message_instead_of_login_success() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -2744,7 +2789,8 @@ mod test { async fn test_grant_login_with_scanned_qr_code_unexpected_message_instead_of_login_success() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -2873,7 +2919,8 @@ mod test { async fn test_grant_login_with_generated_qr_code_secure_channel_error() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); @@ -2998,7 +3045,8 @@ mod test { async fn test_grant_login_with_scanned_qr_code_secure_channel_error() { let server = MatrixMockServer::new().await; let rendezvous_server = Arc::new( - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await, + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await, ); debug!("Set up rendezvous server mock at {}", rendezvous_server.rendezvous_url); diff --git a/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs b/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs index d58b39c3f..ddc300dce 100644 --- a/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs +++ b/crates/matrix-sdk/src/authentication/oauth/qrcode/login.rs @@ -301,6 +301,8 @@ impl<'a> IntoFuture for LoginWithQrCode<'a> { // scanned the QR code, we're certain that the secure channel is // secure, under the assumption that we didn't scan the wrong QR code. // -- MSC4108 Secure channel setup steps 3-5 + trace!("Trying to establish the secure channel"); + let channel = self.establish_secure_channel().await?; trace!("Established the secure channel."); @@ -385,6 +387,8 @@ impl<'a> IntoFuture for LoginWithGeneratedQrCode<'a> { Box::pin(async move { // Establish and verify the secure channel. // -- MSC4108 Secure channel setup all steps + trace!("Trying to establish the secure channel"); + let mut channel = self.establish_secure_channel().await?; trace!("Established the secure channel."); @@ -597,11 +601,11 @@ mod test { alice.send_json(message).await.unwrap(); } - #[async_test] - async fn test_qr_login() { + async fn test_qr_login(msc_4388: bool) { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, msc_4388) + .await; let (sender, receiver) = tokio::sync::oneshot::channel(); let oauth_server = server.oauth(); @@ -622,16 +626,28 @@ mod test { server.mock_query_keys().ok().expect(1).named("query_keys").mount().await; let client = HttpClient::new(reqwest::Client::new(), Default::default()); - let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url, false) + let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url, msc_4388) .await .expect("Alice should be able to create a secure channel."); - assert_let!( - QrCodeIntentData::Msc4108 { - data: Msc4108IntentData::Reciprocate { server_name }, - .. - } = &alice.qr_code_data().intent_data() - ); + assert_eq!(alice.qr_code_data().intent(), QrCodeIntent::Reciprocate); + + let server_name = if msc_4388 { + assert_let!( + QrCodeIntentData::Msc4388 { base_url, .. } = &alice.qr_code_data().intent_data() + ); + + base_url.to_string() + } else { + assert_let!( + QrCodeIntentData::Msc4108 { + data: Msc4108IntentData::Reciprocate { server_name }, + .. + } = &alice.qr_code_data().intent_data() + ); + + server_name.to_owned() + }; let bob = Client::builder() .server_name_or_homeserver_url(server_name) @@ -679,6 +695,16 @@ mod test { assert!(own_identity.is_verified()); } + #[async_test] + async fn test_qr_login_msc_4108() { + test_qr_login(false).await; + } + + #[async_test] + async fn test_qr_login_msc_4388() { + test_qr_login(true).await; + } + async fn grant_login_with_generated_qr( alice: &Client, qr_receiver: tokio::sync::oneshot::Receiver, @@ -762,11 +788,11 @@ mod test { .expect("Alice should be able to send the `m.login.secrets` message to Bob"); } - #[async_test] - async fn test_generated_qr_login() { + async fn test_generated_qr_login(msc_4388: bool) { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, msc_4388) + .await; let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel(); let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel(); @@ -803,18 +829,32 @@ mod test { .expect("Should be able to create a client for Bob"); let secure_channel = - SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url, false) + SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url, msc_4388) .await .expect("Bob should be able to create a secure channel"); - assert_matches!( - secure_channel.qr_code_data().intent_data(), - QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. } - ); + assert_eq!(secure_channel.qr_code_data().intent(), QrCodeIntent::Login); + + if msc_4388 { + assert_matches!( + secure_channel.qr_code_data().intent_data(), + QrCodeIntentData::Msc4388 { .. } + ); + } else { + assert_matches!( + secure_channel.qr_code_data().intent_data(), + QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. } + ); + } let registration_data = mock_client_metadata().into(); let bob_oauth = bob.oauth(); - let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate(); + let mut bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate(); + + if msc_4388 { + bob_login.with_msc4388_support(); + } + let mut bob_updates = bob_login.subscribe_to_progress(); let updates_task = spawn(async move { @@ -868,10 +908,20 @@ mod test { } #[async_test] - async fn test_generated_qr_login_with_homeserver_swap() { + async fn test_generated_qr_login_msc_4108() { + test_generated_qr_login(false).await; + } + + #[async_test] + async fn test_generated_qr_login_msc_4388() { + test_generated_qr_login(true).await; + } + + async fn test_generated_qr_login_with_homeserver_swap(msc_4388: bool) { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, msc_4388) + .await; let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel(); let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel(); @@ -912,18 +962,32 @@ mod test { .expect("Should be able to create a client for Bob"); let secure_channel = - SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url, false) + SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url, msc_4388) .await .expect("Bob should be able to create a secure channel"); - assert_matches!( - secure_channel.qr_code_data().intent_data(), - QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. } - ); + assert_eq!(secure_channel.qr_code_data().intent(), QrCodeIntent::Login); + + if msc_4388 { + assert_matches!( + secure_channel.qr_code_data().intent_data(), + QrCodeIntentData::Msc4388 { .. } + ); + } else { + assert_matches!( + secure_channel.qr_code_data().intent_data(), + QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. } + ); + } let registration_data = mock_client_metadata().into(); let bob_oauth = bob.oauth(); - let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate(); + let mut bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate(); + + if msc_4388 { + bob_login.with_msc4388_support(); + } + let mut bob_updates = bob_login.subscribe_to_progress(); let updates_task = spawn(async move { @@ -976,9 +1040,20 @@ mod test { assert!(own_identity.is_verified()); } + #[async_test] + async fn test_generated_qr_login_with_homeserver_swap_msc_4108() { + test_generated_qr_login_with_homeserver_swap(false).await; + } + + #[async_test] + async fn test_generated_qr_login_with_homeserver_swap_msc_4388() { + test_generated_qr_login_with_homeserver_swap(true).await; + } + async fn test_failure( token_response: TokenResponse, alice_behavior: AliceBehaviour, + msc_4388: bool, ) -> Result<(), QRCodeLoginError> { let server = MatrixMockServer::new().await; let expiration = match alice_behavior { @@ -986,7 +1061,8 @@ mod test { _ => Duration::MAX, }; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration, msc_4388) + .await; let (sender, receiver) = tokio::sync::oneshot::channel(); let oauth_server = server.oauth(); @@ -1028,16 +1104,30 @@ mod test { server.mock_who_am_i().ok().named("whoami").mount().await; let client = HttpClient::new(reqwest::Client::new(), Default::default()); - let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url, false) + let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url, msc_4388) .await .expect("Alice should be able to create a secure channel."); - assert_let!( - QrCodeIntentData::Msc4108 { - data: Msc4108IntentData::Reciprocate { server_name }, - .. - } = &alice.qr_code_data().intent_data() - ); + assert_eq!(alice.qr_code_data().intent(), QrCodeIntent::Reciprocate); + + let server_name = if msc_4388 { + assert_let!( + QrCodeIntentData::Msc4388 { base_url, .. } = &alice.qr_code_data().intent_data() + ); + + base_url.to_string() + } else { + assert_let!( + QrCodeIntentData::Msc4108 { + data: Msc4108IntentData::Reciprocate { server_name }, + .. + } = &alice.qr_code_data().intent_data() + ); + + server_name.to_owned() + }; + + assert_eq!(alice.qr_code_data().intent(), QrCodeIntent::Reciprocate); let bob = Client::builder() .server_name_or_homeserver_url(server_name) @@ -1082,6 +1172,7 @@ mod test { async fn test_generated_failure( token_response: TokenResponse, alice_behavior: AliceBehaviour, + msc_4388: bool, ) -> Result<(), QRCodeLoginError> { let server = MatrixMockServer::new().await; let expiration = match alice_behavior { @@ -1089,7 +1180,8 @@ mod test { _ => Duration::MAX, }; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration, msc_4388) + .await; let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel(); let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel(); @@ -1148,18 +1240,32 @@ mod test { .expect("Should be able to create a client for Bob"); let secure_channel = - SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url, false) + SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url, msc_4388) .await .expect("Bob should be able to create a secure channel"); - assert_matches!( - secure_channel.qr_code_data().intent_data(), - QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. } - ); + assert_eq!(secure_channel.qr_code_data().intent(), QrCodeIntent::Login); + + if msc_4388 { + assert_matches!( + secure_channel.qr_code_data().intent_data(), + QrCodeIntentData::Msc4388 { .. } + ); + } else { + assert_matches!( + secure_channel.qr_code_data().intent_data(), + QrCodeIntentData::Msc4108 { data: Msc4108IntentData::Login, .. } + ); + } let registration_data = mock_client_metadata().into(); let bob_oauth = bob.oauth(); - let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate(); + let mut bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate(); + + if msc_4388 { + bob_login.with_msc4388_support(); + } + let mut bob_updates = bob_login.subscribe_to_progress(); let _updates_task = spawn(async move { @@ -1200,9 +1306,9 @@ mod test { bob_login.await } - #[async_test] - async fn test_qr_login_refused_access_token() { - let result = test_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await; + async fn test_qr_login_refused_access_token(msc_4388: bool) { + let result = + test_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath, msc_4388).await; assert_let!(Err(QRCodeLoginError::OAuth(e)) = result); assert_eq!( @@ -1213,9 +1319,22 @@ mod test { } #[async_test] - async fn test_generated_qr_login_refused_access_token() { - let result = - test_generated_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await; + async fn test_qr_login_refused_access_token_msc_4108() { + test_qr_login_refused_access_token(false).await + } + + #[async_test] + async fn test_qr_login_refused_access_token_msc_4388() { + test_qr_login_refused_access_token(true).await + } + + async fn test_generated_qr_login_refused_access_token(msc_4388: bool) { + let result = test_generated_failure( + TokenResponse::AccessDenied, + AliceBehaviour::HappyPath, + msc_4388, + ) + .await; assert_let!(Err(QRCodeLoginError::OAuth(e)) = result); assert_eq!( @@ -1226,8 +1345,18 @@ mod test { } #[async_test] - async fn test_qr_login_expired_token() { - let result = test_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await; + async fn test_generated_qr_login_refused_access_token_msc_4108() { + test_generated_qr_login_refused_access_token(false).await; + } + + #[async_test] + async fn test_generated_qr_login_refused_access_token_msc_4388() { + test_generated_qr_login_refused_access_token(true).await; + } + + async fn test_qr_login_expired_token(msc_4388: bool) { + let result = + test_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath, msc_4388).await; assert_let!(Err(QRCodeLoginError::OAuth(e)) = result); assert_eq!( @@ -1238,9 +1367,22 @@ mod test { } #[async_test] - async fn test_generated_qr_login_expired_token() { - let result = - test_generated_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await; + async fn test_qr_login_expired_token_msc_4108() { + test_qr_login_expired_token(false).await; + } + + #[async_test] + async fn test_qr_login_expired_token_msc_4388() { + test_qr_login_expired_token(true).await; + } + + async fn test_generated_qr_login_expired_token(msc_4388: bool) { + let result = test_generated_failure( + TokenResponse::ExpiredToken, + AliceBehaviour::HappyPath, + msc_4388, + ) + .await; assert_let!(Err(QRCodeLoginError::OAuth(e)) = result); assert_eq!( @@ -1251,8 +1393,18 @@ mod test { } #[async_test] - async fn test_qr_login_declined_protocol() { - let result = test_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await; + async fn test_generated_qr_login_expired_token_msc_4108() { + test_generated_qr_login_expired_token(false).await; + } + + #[async_test] + async fn test_generated_qr_login_expired_token_msc_4388() { + test_generated_qr_login_expired_token(true).await; + } + + async fn test_qr_login_declined_protocol(msc_4388: bool) { + let result = + test_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol, msc_4388).await; assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result); assert_eq!( @@ -1263,9 +1415,19 @@ mod test { } #[async_test] - async fn test_generated_qr_login_declined_protocol() { + async fn test_qr_login_declined_protocol_msc_4108() { + test_qr_login_declined_protocol(false).await; + } + + #[async_test] + async fn test_qr_login_declined_protocol_msc_4388() { + test_qr_login_declined_protocol(true).await; + } + + async fn test_generated_qr_login_declined_protocol(msc_4388: bool) { let result = - test_generated_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await; + test_generated_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol, msc_4388) + .await; assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result); assert_eq!( @@ -1276,37 +1438,57 @@ mod test { } #[async_test] - async fn test_qr_login_unexpected_message() { - let result = test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await; + async fn test_generated_qr_login_declined_protocol_msc_4108() { + test_generated_qr_login_declined_protocol(false).await; + } + + #[async_test] + async fn test_generated_qr_login_declined_protocol_msc_4388() { + test_generated_qr_login_declined_protocol(true).await; + } + + async fn test_qr_login_unexpected_message(msc_4388: bool) { + let result = + test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage, msc_4388).await; assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result); assert_eq!(expected, "m.login.protocol_accepted"); } #[async_test] - async fn test_generated_qr_login_unexpected_message() { - let result = - test_generated_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await; - - assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result); - assert_eq!(expected, "m.login.protocol_accepted"); + async fn test_qr_login_unexpected_message_msc_4108() { + test_qr_login_unexpected_message(false).await; } #[async_test] - async fn test_qr_login_unexpected_message_instead_of_secrets() { + async fn test_qr_login_unexpected_message_msc_4388() { + test_qr_login_unexpected_message(true).await; + } + + async fn test_generated_qr_login_unexpected_message(msc_4388: bool) { let result = - test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessageInsteadOfSecrets) + test_generated_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage, msc_4388) .await; assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result); - assert_eq!(expected, "m.login.secrets"); + assert_eq!(expected, "m.login.protocol_accepted"); } #[async_test] - async fn test_generated_qr_login_unexpected_message_instead_of_secrets() { - let result = test_generated_failure( + async fn test_generated_qr_login_unexpected_message_msc_4108() { + test_generated_qr_login_unexpected_message(false).await; + } + + #[async_test] + async fn test_generated_qr_login_unexpected_message_msc_4388() { + test_generated_qr_login_unexpected_message(true).await; + } + + async fn test_qr_login_unexpected_message_instead_of_secrets(msc_4388: bool) { + let result = test_failure( TokenResponse::Ok, AliceBehaviour::UnexpectedMessageInsteadOfSecrets, + msc_4388, ) .await; @@ -1315,41 +1497,114 @@ mod test { } #[async_test] - async fn test_qr_login_refuse_secrets() { - let result = test_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await; + async fn test_qr_login_unexpected_message_instead_of_secrets_msc_4108() { + test_qr_login_unexpected_message_instead_of_secrets(false).await; + } + + #[async_test] + async fn test_qr_login_unexpected_message_instead_of_secrets_msc_4388() { + test_qr_login_unexpected_message_instead_of_secrets(true).await; + } + + async fn test_generated_qr_login_unexpected_message_instead_of_secrets(msc_4388: bool) { + let result = test_generated_failure( + TokenResponse::Ok, + AliceBehaviour::UnexpectedMessageInsteadOfSecrets, + msc_4388, + ) + .await; + + assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result); + assert_eq!(expected, "m.login.secrets"); + } + + #[async_test] + async fn test_generated_qr_login_unexpected_message_instead_of_secrets_msc_4108() { + test_generated_qr_login_unexpected_message_instead_of_secrets(false).await; + } + + #[async_test] + async fn test_generated_qr_login_unexpected_message_instead_of_secrets_msc_4388() { + test_generated_qr_login_unexpected_message_instead_of_secrets(true).await; + } + + async fn test_qr_login_refuse_secrets(msc_4388: bool) { + let result = test_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets, msc_4388).await; assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result); assert_eq!(reason, LoginFailureReason::DeviceNotFound); } #[async_test] - async fn test_generated_qr_login_refuse_secrets() { - let result = test_generated_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await; - - assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result); - assert_eq!(reason, LoginFailureReason::DeviceNotFound); + async fn test_qr_login_refuse_secrets_msc_4108() { + test_qr_login_refuse_secrets(false).await } #[async_test] - async fn test_qr_login_session_expired() { - let result = test_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire).await; - - assert_matches!(result, Err(QRCodeLoginError::NotFound)); + async fn test_qr_login_refuse_secrets_msc_4388() { + test_qr_login_refuse_secrets(true).await } - #[async_test] - async fn test_generated_qr_login_session_expired() { + async fn test_generated_qr_login_refuse_secrets(msc_4388: bool) { let result = - test_generated_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire).await; + test_generated_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets, msc_4388) + .await; + + assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result); + assert_eq!(reason, LoginFailureReason::DeviceNotFound); + } + + #[async_test] + async fn test_generated_qr_login_refuse_secrets_msc_4108() { + test_generated_qr_login_refuse_secrets(false).await; + } + + #[async_test] + async fn test_generated_qr_login_refuse_secrets_msc_4388() { + test_generated_qr_login_refuse_secrets(true).await; + } + + async fn test_qr_login_session_expired(msc_4388: bool) { + let result = + test_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire, msc_4388).await; assert_matches!(result, Err(QRCodeLoginError::NotFound)); } + #[async_test] + async fn test_qr_login_session_expired_msc_4108() { + test_qr_login_session_expired(false).await; + } + + #[async_test] + async fn test_qr_login_session_expired_msc_4388() { + test_qr_login_session_expired(true).await; + } + + async fn test_generated_qr_login_session_expired(msc_4388: bool) { + let result = + test_generated_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire, msc_4388) + .await; + + assert_matches!(result, Err(QRCodeLoginError::NotFound)); + } + + #[async_test] + async fn test_generated_qr_login_session_expired_msc_4108() { + test_generated_qr_login_session_expired(false).await; + } + + #[async_test] + async fn test_generated_qr_login_session_expired_msc_4388() { + test_generated_qr_login_session_expired(true).await; + } + #[async_test] async fn test_device_authorization_endpoint_missing() { let server = MatrixMockServer::new().await; let rendezvous_server = - MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX, false) + .await; let (sender, receiver) = tokio::sync::oneshot::channel(); let oauth_server = server.oauth(); diff --git a/crates/matrix-sdk/src/authentication/oauth/qrcode/rendezvous_channel/msc_4388.rs b/crates/matrix-sdk/src/authentication/oauth/qrcode/rendezvous_channel/msc_4388.rs index 987b367b3..89b236b96 100644 --- a/crates/matrix-sdk/src/authentication/oauth/qrcode/rendezvous_channel/msc_4388.rs +++ b/crates/matrix-sdk/src/authentication/oauth/qrcode/rendezvous_channel/msc_4388.rs @@ -343,7 +343,7 @@ mod test { let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry()); let InboundChannelCreationResult { channel: bob, initial_message: _ } = - Channel::create_inbound(client, &base_url, &rendezvous_id.to_owned()).await.expect( + Channel::create_inbound(client, &base_url, rendezvous_id).await.expect( "We should be able to create a rendezvous channel from a received message", ); diff --git a/crates/matrix-sdk/src/authentication/oauth/qrcode/secure_channel/mod.rs b/crates/matrix-sdk/src/authentication/oauth/qrcode/secure_channel/mod.rs index 7f6b29914..124a7794c 100644 --- a/crates/matrix-sdk/src/authentication/oauth/qrcode/secure_channel/mod.rs +++ b/crates/matrix-sdk/src/authentication/oauth/qrcode/secure_channel/mod.rs @@ -363,8 +363,10 @@ pub(super) mod test { use matrix_sdk_common::executor::spawn; use matrix_sdk_test::async_test; use ruma::time::Instant; + use serde::Deserialize; use serde_json::json; use similar_asserts::assert_eq; + use tracing::trace; use url::Url; use vodozemac::hpke::DigitMode; use wiremock::{ @@ -389,7 +391,24 @@ pub(super) mod test { } impl MockedRendezvousServer { - pub async fn new(server: &MockServer, location: &str, expiration: Duration) -> Self { + pub async fn new( + server: &MockServer, + location: &str, + expiration: Duration, + msc_4388: bool, + ) -> Self { + if msc_4388 { + Self::new_msc4388(server, expiration).await + } else { + Self::new_msc4108(server, location, expiration).await + } + } + + pub async fn new_msc4108( + server: &MockServer, + location: &str, + expiration: Duration, + ) -> Self { let content: Arc>> = Mutex::default().into(); let created: Arc>> = Mutex::default().into(); let etag = Arc::new(AtomicU8::new(0)); @@ -506,16 +525,140 @@ pub(super) mod test { rendezvous_url, } } + + pub async fn new_msc4388(server: &MockServer, expiration: Duration) -> Self { + #[derive(Debug, Deserialize)] + struct PutContent { + #[allow(dead_code)] + sequence_token: String, + data: String, + } + + const RENDEZVOUS_ID: &str = "abcdEFG12345"; + + let content: Arc>> = Mutex::default().into(); + let created: Arc>> = Mutex::default().into(); + let sequence_token = Arc::new(AtomicU8::new(0)); + + let homeserver_url = Url::parse(&server.uri()) + .expect("We should be able to parse the example homeserver"); + + let rendezvous_url = homeserver_url + .join(RENDEZVOUS_ID) + .expect("We should be able to create a rendezvous URL"); + + let post_guard = server + .register_as_scoped( + Mock::given(method("POST")) + .and(path("/_matrix/client/unstable/io.element.msc4388/rendezvous")) + .respond_with({ + *created.lock().unwrap() = Some(Instant::now()); + + trace!("Creating a new rendezvous channel ID: {RENDEZVOUS_ID}"); + + ResponseTemplate::new(200).set_body_json(json!({ + "id": RENDEZVOUS_ID, + "sequence_token": "0", + "expires_in_ms": 100_000, + })) + }), + ) + .await; + + let put_guard = server + .register_as_scoped( + Mock::given(method("PUT")) + .and(path(format!( + "/_matrix/client/unstable/io.element.msc4388/rendezvous/{RENDEZVOUS_ID}" + ))) + .respond_with({ + let content = content.clone(); + let created = created.clone(); + let sequence_token = sequence_token.clone(); + + + move |request: &wiremock::Request| { + // Fail the request if the session has expired. + if created.lock().unwrap().unwrap().elapsed() > expiration { + return ResponseTemplate::new(404).set_body_json(json!({ + "errcode": "M_NOT_FOUND", + "error": "This rendezvous session does not exist.", + })); + } + + let request_content: PutContent = request.body_json().unwrap(); + *content.lock().unwrap() = Some(request_content.data); + + let prev_token = + sequence_token.fetch_add(1, Ordering::SeqCst); + + trace!("Putting new content into the rendezvous channel ID: {RENDEZVOUS_ID}"); + + ResponseTemplate::new(200).set_body_json(json!({ + "sequence_token": (prev_token + 1 ).to_string(), + + })) + } + }), + ) + .await; + + let get_guard = server + .register_as_scoped( + Mock::given(method("GET")) + .and(path(format!( + "/_matrix/client/unstable/io.element.msc4388/rendezvous/{RENDEZVOUS_ID}" + ))) + .respond_with({ + let content = content.clone(); + let created = created.clone(); + let sequence_token = sequence_token.clone(); + + move |_: &wiremock::Request| { + // Fail the request if the session has expired. + if created.lock().unwrap().unwrap().elapsed() > expiration { + return ResponseTemplate::new(404).set_body_json(json!({ + "errcode": "M_NOT_FOUND", + "error": "This rendezvous session does not exist.", + })); + } + + let content = content.lock().unwrap(); + let current_sequence_token = sequence_token.load(Ordering::SeqCst); + + let content = content.clone(); + + ResponseTemplate::new(200).set_body_json(json!({ + "data": content.unwrap_or_default(), + "sequence_token": current_sequence_token.to_string(), + "expires_in_ms": 100_000, + })) + } + }), + ) + .await; + + Self { + expiration, + content, + created, + etag: sequence_token, + post_guard, + put_guard, + get_guard, + homeserver_url, + rendezvous_url, + } + } } - #[async_test] - async fn test_creation() { + async fn test_creation(msc_4388: bool) { let server = MockServer::start().await; let rendezvous_server = - MockedRendezvousServer::new(&server, "abcdEFG12345", Duration::MAX).await; + MockedRendezvousServer::new(&server, "abcdEFG12345", Duration::MAX, msc_4388).await; let client = HttpClient::new(reqwest::Client::new(), Default::default()); - let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url, false) + let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url, msc_4388) .await .expect("Alice should be able to create a secure channel."); @@ -549,4 +692,14 @@ pub(super) mod test { assert_eq!(bob.channel.rendezvous_info(), alice.channel.rendezvous_info()); } + + #[async_test] + async fn test_creation_msc4388() { + test_creation(true).await; + } + + #[async_test] + async fn test_creation_msc4108() { + test_creation(false).await; + } }